mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	 deee214a95
			
		
	
	deee214a95
	
	
	
		
			
			* initial commit with working optimizer * update ACKNOWLEDGMENTS.md * nits and adding it to test * nits * G.astype(mx.bfloat16) to G.astype(G.dtype) * G.ndim >= 2 to assert G.ndim == 2 * remove comments * replace with mx.addmm * remove comments * format * nits * match muon * fix addmm --------- Co-authored-by: Awni Hannun <awni@apple.com>
		
			
				
	
	
		
			585 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			585 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import inspect
 | |
| import math
 | |
| import unittest
 | |
| from functools import partial
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx.nn as nn
 | |
| import mlx.optimizers as opt
 | |
| import mlx.utils
 | |
| import mlx_tests
 | |
| import numpy as np
 | |
| from mlx.utils import tree_flatten, tree_map, tree_unflatten
 | |
| 
 | |
| try:
 | |
|     import torch
 | |
|     import torch.nn.functional as F
 | |
| 
 | |
|     has_torch = True
 | |
| except ImportError as e:
 | |
|     has_torch = False
 | |
| 
 | |
| 
 | |
def get_all_optimizers():
    """Collect the optimizer classes exported by ``mlx.optimizers``.

    Returns:
        dict mapping each exported name to its class, for every class that
        subclasses ``opt.Optimizer`` other than the base class itself.
    """

    def _is_optimizer_subclass(member):
        # Exclude the base class so only concrete entries remain.
        return (
            inspect.isclass(member)
            and issubclass(member, opt.Optimizer)
            and member != opt.Optimizer
        )

    return {
        member_name: member
        for member_name, member in inspect.getmembers(opt)
        if _is_optimizer_subclass(member)
    }
 | |
| 
 | |
| 
 | |
def tree_equal(fn, *args):
    """Return True when ``fn`` yields a truthy value at every leaf.

    ``fn`` is applied leaf-wise across ``args`` with ``tree_map``. The
    resulting tree is flattened and its leaf values are reduced with ``all``.
    """
    mapped = tree_map(fn, *args)
    leaf_values = (value for _, value in tree_flatten(mapped))
    return all(leaf_values)
 | |
| 
 | |
| 
 | |
# Every optimizer class exported by mlx.optimizers. MultiOptimizer is left
# out because it is built from a list of sub-optimizers rather than a single
# learning rate, so the generic ``optim_class(...)`` loops cannot construct it.
optimizers_dict = get_all_optimizers()
optimizers_dict.pop("MultiOptimizer")
 | |
| 
 | |
| 
 | |
class TestOptimizers(mlx_tests.MLXTestCase):
    """Tests for individual optimizers.

    Covers state initialization, shape/dtype preservation, parity with
    PyTorch's AdamW, and use of optimizers inside ``mx.compile``.
    """

    @staticmethod
    def _make_params():
        """Return a fresh parameter tree: a dict holding a list and a bare array.

        Built anew on every call so tests never share (or mutate) arrays.
        """
        return {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }

    def _assert_state_zeros(self, params, state, key):
        """Assert ``state[...][key]`` equals ``zeros_like`` of each parameter.

        ``state`` mirrors ``params`` with a dict at each leaf, as produced by
        ``Optimizer.init`` in the tests below.
        """
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s[key], mx.zeros_like(p)),
                params,
                state,
            )
        )

    def test_optimizer_state(self):
        """``state`` accepts arbitrary entries and can be replaced wholesale."""
        optim = opt.SGD(0.1)
        optim.state["hello"] = "world"
        self.assertEqual(optim.state["hello"], "world")

        optim.state = {0: 1}
        self.assertEqual(optim.state, {0: 1})

    def test_optimizers(self):
        """Every optimizer returns updated parameters with unchanged shapes."""
        params = self._make_params()
        grads = tree_map(lambda x: mx.ones_like(x), params)

        for optim_class in optimizers_dict.values():
            with self.subTest(optimizer=optim_class.__name__):
                optim = optim_class(0.1)
                update = optim.apply_gradients(grads, params)
                mx.eval(update)
                self.assertTrue(
                    tree_equal(lambda x, y: x.shape == y.shape, params, update)
                )

    def test_types_conserved(self):
        """Every optimizer keeps float16 parameters in float16."""
        params = {"w": mx.ones((5, 5), mx.float16)}
        grads = tree_map(lambda x: mx.ones_like(x), params)
        for optim_class in optimizers_dict.values():
            with self.subTest(optimizer=optim_class.__name__):
                optim = optim_class(0.1)
                update = optim.apply_gradients(grads, params)
                self.assertEqual(update["w"].dtype, mx.float16)

    def test_sgd(self):
        params = self._make_params()
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init: the momentum buffer "v" starts at zero.
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        optim.init(params)
        self._assert_state_zeros(params, optim.state, "v")

        # Implicit init: one step from a zero buffer leaves v equal to the grad.
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        optim.apply_gradients(grads, params)
        self.assertTrue(
            tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state)
        )

    def test_rmsprop(self):
        params = self._make_params()
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.RMSprop(learning_rate=1e-2)
        optim.init(params)
        self._assert_state_zeros(params, optim.state, "v")

        # Implicit init: with unit gradients and a zero start, the test expects
        # v == (1 - alpha) * g after one step.
        alpha = 0.99
        optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha)
        optim.apply_gradients(grads, params)
        self.assertTrue(
            tree_equal(
                lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state
            )
        )

    def test_adagrad(self):
        params = self._make_params()

        # Explicit init
        optim = opt.Adagrad(learning_rate=1e-2)
        optim.init(params)
        self._assert_state_zeros(params, optim.state, "v")

    def test_adadelta(self):
        params = self._make_params()

        # Explicit init: both accumulators start at zero.
        optim = opt.AdaDelta(learning_rate=1e-2)
        optim.init(params)
        for key in ("v", "u"):
            self._assert_state_zeros(params, optim.state, key)

    def test_adam(self):
        params = self._make_params()
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init: both moment buffers start at zero for the Adam family.
        for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]:
            with self.subTest(optimizer=optimizer.__name__):
                optim = optimizer(learning_rate=1e-2)
                optim.init(params)
                for key in ("v", "m"):
                    self._assert_state_zeros(params, optim.state, key)

        # Test for correct gradient type propagation
        params = tree_map(lambda x: x.astype(mx.float16), params)
        grads = tree_map(lambda x: x.astype(mx.float16), grads)
        optim = opt.Adam(1e-2, bias_correction=True)
        new_params = optim.apply_gradients(grads, params)
        self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params))

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_adamw_matches_pytorch(self):
        """One AdamW step in MLX matches torch.optim.AdamW from the same init."""
        mx.random.seed(0)
        np.random.seed(0)

        model = nn.Linear(3, 1)
        init_weight = np.array(model.weight.tolist())
        init_bias = np.array(model.bias.tolist())

        def loss_fn(model, x, y):
            pred = model(x)
            return nn.losses.mse_loss(pred, y)

        x = np.random.rand(3, 3)
        y = np.random.rand(3, 1)

        optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
        _, grads = loss_and_grad_fn(model, mx.array(x), mx.array(y))
        optimizer.update(model, grads)

        # Equivalent torch code
        torch_model = torch.nn.Linear(3, 1)

        # copy over the parameters
        torch_model.weight.data = torch.tensor(init_weight, dtype=torch.float32)
        torch_model.bias.data = torch.tensor(init_bias, dtype=torch.float32)

        torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)
        torch_optimizer.zero_grad()
        pred = torch_model(torch.tensor(x, dtype=torch.float32))
        torch_loss = torch.nn.MSELoss()(pred, torch.tensor(y, dtype=torch.float32))
        torch_loss.backward()
        torch_optimizer.step()

        # Gradients must agree before comparing the updated parameters.
        for name, param in torch_model.named_parameters():
            mlx_grad = np.array(grads[name])
            torch_grad = param.grad.detach().numpy()
            self.assertTrue(np.allclose(torch_grad, mlx_grad))

        for name, param in torch_model.named_parameters():
            mlx_param = np.array(model[name])
            torch_param = param.data.detach().numpy()
            self.assertTrue(np.allclose(torch_param, mlx_param))

    def test_lion(self):
        params = self._make_params()

        # Explicit init
        optim = opt.Lion(learning_rate=1e-2)
        optim.init(params)
        self._assert_state_zeros(params, optim.state, "m")

    def test_adafactor(self):
        """Adafactor keeps dtype and shape over two steps and counts them."""
        for dtype in (mx.float32, mx.float16):
            with self.subTest(dtype=dtype):
                x = mx.zeros((5, 5), dtype)
                params = {"x": x}
                grad = {"x": mx.ones_like(x)}
                optimizer = opt.Adafactor()
                for _ in range(2):
                    xp = optimizer.apply_gradients(grad, params)
                    self.assertEqual(xp["x"].dtype, x.dtype)
                    self.assertEqual(xp["x"].shape, x.shape)
                self.assertEqual(optimizer.state["step"], 2)

    def test_muon(self):
        """Muon handles 1-D, 2-D and 4-D parameters in several configurations."""
        params = {
            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
            "second": mx.zeros((3, 3)),
            "conv": mx.zeros((16, 8, 3, 3)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
        optim.init(params)
        self._assert_state_zeros(params, optim.state, "v")

        # Test update
        updated_params = optim.apply_gradients(grads, params)

        # Check that shapes are preserved
        self.assertTrue(
            tree_equal(lambda p, u: p.shape == u.shape, params, updated_params)
        )

        # Check that at least one parameter actually changed
        self.assertFalse(
            tree_equal(lambda p, u: mx.array_equal(p, u), params, updated_params)
        )

        # Other configurations must run without error.
        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
        optim_no_nesterov.apply_gradients(grads, params)

        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
        optim_no_momentum.apply_gradients(grads, params)

    def test_compiled_optimizer(self):
        """Pure and impure compiled SGD steps match an uncompiled step."""
        model = nn.Linear(10, 10)
        x = mx.random.uniform(shape=(2, 10))
        orig_params = model.parameters()

        def module_loss(model, x):
            return model(x).sum()

        # Uncompiled reference step.
        eager_optim = opt.SGD(learning_rate=1e-2, momentum=0.9)

        def eager_step(x):
            _, grad = nn.value_and_grad(model, module_loss)(model, x)
            eager_optim.update(model, grad)

        eager_step(x)
        uncompiled_params = model.parameters()

        # Pure version: parameters and optimizer state are explicit inputs/outputs.
        def params_loss(params, x):
            model.update(params)
            return model(x).sum()

        model.update(orig_params)
        pure_optim = opt.SGD(learning_rate=1e-2, momentum=0.9)

        @mx.compile
        def pure_step(params, opt_state, x):
            grad = mx.grad(params_loss)(params, x)
            pure_optim.state = opt_state
            params = pure_optim.apply_gradients(grad, params)
            return params, pure_optim.state

        pure_optim.init(model.parameters())
        pure_params, _ = pure_step(model.parameters(), pure_optim.state, x)
        self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"]))
        self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"]))

        # Impure version: model and optimizer state are captured by mx.compile.
        model.update(orig_params)
        impure_optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        state = [model.state, impure_optim.state]

        @partial(mx.compile, inputs=state, outputs=state)
        def impure_step(x):
            _, grad = nn.value_and_grad(model, module_loss)(model, x)
            impure_optim.update(model, grad)

        impure_step(x)
        impure_params = model.parameters()
        self.assertTrue(
            mx.allclose(impure_params["weight"], uncompiled_params["weight"])
        )
        self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"]))

    def test_update_lr_compiled(self):
        """A compiled update sees learning-rate changes made after compilation.

        The optimizer state is declared as a compile input. A negative learning
        rate with unit gradients therefore moves the ones to 2, then (lr=-2) to 3.
        """
        params = {"w": mx.ones((5, 5))}
        grads = tree_map(lambda x: mx.ones_like(x), params)
        optim = opt.SGD(-1.0)

        @partial(mx.compile, inputs=optim.state)
        def update(grads):
            return optim.apply_gradients(grads, params)

        result = update(grads)
        self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0)))
        optim.learning_rate = -2.0
        result = update(grads)
        self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
 | |
| 
 | |
| 
 | |
class TestSchedulers(mlx_tests.MLXTestCase):
    """Tests for learning-rate schedules and their use inside optimizers.

    NOTE: the last three tests (gradient clipping, restoring optimizer state,
    MultiOptimizer) exercise optimizer utilities rather than schedulers.
    """

    @staticmethod
    def _make_model():
        """Build a small module mixing parameterized and parameter-free layers."""

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = nn.Linear(2, 2)
                self.drop = nn.Dropout(p=0.5)
                self.l2 = nn.Linear(2, 2)
                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]

        return Model()

    def test_decay_lr(self):
        """After the i-th update (0-based) the optimizer's lr equals schedule(i)."""
        for optim_class in optimizers_dict.values():
            with self.subTest(optimizer=optim_class.__name__):
                lr_schedule = opt.step_decay(1e-1, 0.9, 1)
                optimizer = optim_class(learning_rate=lr_schedule)

                params = {"w": mx.ones((5, 5))}
                grads = tree_map(lambda x: mx.ones_like(x), params)

                for it in range(10):
                    optimizer.apply_gradients(grads, params)
                    expected_lr = 0.1 * (0.9**it)
                    self.assertAlmostEqual(
                        optimizer.learning_rate, expected_lr, delta=1e-7
                    )

    def test_step_decay(self):
        lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
        lr = lr_schedule(2500)
        expected_lr = 0.1 * (0.9**2)
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_exponential_decay(self):
        lr_schedule = opt.exponential_decay(1e-1, 0.99)
        lr = lr_schedule(10)
        expected_lr = 0.1 * (0.99**10)
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_cosine_decay(self):
        lr_schedule = opt.cosine_decay(0.1, 10)
        lr = lr_schedule(4)
        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

        # With an end value, the schedule stays above it before decay_steps and
        # is clamped to it afterwards.
        lr_schedule = opt.cosine_decay(0.1, 10, 0.05)
        lr = lr_schedule(9)
        expected_end_lr = 0.05
        self.assertGreater(lr, expected_end_lr)
        lr = lr_schedule(20)
        self.assertEqual(lr, expected_end_lr)

    def test_schedule_joiner(self):
        # N schedules need exactly N - 1 boundaries.
        boundaries = [2, 3, 4]
        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
        with self.assertRaises(ValueError):
            opt.schedulers.join_schedules(schedules, boundaries)
        boundaries = [2, 4]
        schedule = opt.schedulers.join_schedules(schedules, boundaries)
        self.assertEqual(schedule(0).item(), 3)
        self.assertEqual(schedule(1).item(), 3)
        self.assertEqual(schedule(2).item(), 4)
        self.assertEqual(schedule(3).item(), 4)
        self.assertEqual(schedule(5).item(), 5)
        self.assertEqual(schedule(7).item(), 5)

    def test_linear_warmup_with_cosine_decay(self):
        """Linear warmup joined to cosine decay, standalone and inside Adam."""
        peak_lr = 1e-5
        warmup_steps = 100
        decay_steps = 100
        boundary = warmup_steps + 1
        warmup_schedule = opt.schedulers.linear_schedule(0.0, peak_lr, warmup_steps)
        cosine_schedule = opt.schedulers.cosine_decay(peak_lr, decay_steps)
        cos_with_warmup = opt.schedulers.join_schedules(
            [warmup_schedule, cosine_schedule], [boundary]
        )
        # The learning rates here range from ~1e-8 to 1e-5, so the tolerance
        # must be far below that for the assertions to mean anything.
        tol = 1e-9

        self.assertEqual(cos_with_warmup(0), 0.0)
        # At the boundary the cosine schedule starts at step 0, i.e. at peak_lr.
        self.assertAlmostEqual(cos_with_warmup(boundary), peak_lr, delta=tol)

        # As test_decay_lr pins, after the i-th update (0-based) the optimizer
        # holds schedule(i). So after 100 updates the value is warmup(99).
        optimizer = opt.Adam(learning_rate=cos_with_warmup)
        for _ in range(100):
            optimizer.update({}, {})
        expected_lr = peak_lr * 99 / warmup_steps
        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=tol)

        # After 200 updates the value is schedule(199). That is past the
        # boundary, so it is the cosine decay evaluated at 199 - boundary.
        for _ in range(100):
            optimizer.update({}, {})
        cosine_step = 199 - boundary
        expected_lr = (
            peak_lr * 0.5 * (1.0 + math.cos(math.pi * cosine_step / decay_steps))
        )
        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=tol)

    def test_compile_with_schedule(self):
        lr_schedule = opt.exponential_decay(1e-1, 0.9)
        optimizer = opt.SGD(learning_rate=lr_schedule)

        @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
        def update():
            optimizer.update({}, {})

        for step in range(5):
            update()
            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())

    def test_clip_grad_norm(self):
        # Test with small gradients that do not require clipping
        small_grads = {
            "first": [mx.array([0.1, 0.2]), mx.array([0.1])],
            "second": mx.array([0.3]),
        }
        max_norm = 10.0  # A large max_norm that shouldn't trigger clipping
        clipped_grads, _ = opt.clip_grad_norm(small_grads, max_norm)
        self.assertTrue(
            tree_equal(lambda x, y: mx.array_equal(x, y), small_grads, clipped_grads),
            "Gradients should not be modified when clipping is not necessary.",
        )

        # Test with large gradients that require clipping
        large_grads = {
            "first": [mx.array([10, 20]), mx.array([10])],
            "second": mx.array([30]),
        }
        max_norm = 1.0  # A small max_norm that should trigger clipping
        clipped_grads, total_norm = opt.clip_grad_norm(large_grads, max_norm)
        # Global L2 norm over every leaf of the clipped gradient tree.
        clipped_values = [value for _, value in tree_flatten(clipped_grads)]
        norm_of_clipped = mx.sqrt(
            sum(mx.square(g).sum() for g in clipped_values)
        ).item()
        self.assertAlmostEqual(
            norm_of_clipped,
            max_norm,
            places=6,
            msg="Clipped gradients norm should be close to the specified max_norm.",
        )

        # Ensures that the scaling was done correctly
        scale = max_norm / total_norm
        expected_grads = tree_map(lambda g: g * scale, large_grads)
        self.assertTrue(
            tree_equal(
                lambda x, y: mx.allclose(x, y, atol=1e-6), expected_grads, clipped_grads
            ),
            "Gradients were not scaled correctly during clipping.",
        )

    def test_init_from_state(self):
        """A flattened-then-unflattened Adam state can drive a fresh optimizer."""
        model = self._make_model()
        optimizer = opt.Adam(learning_rate=3e-4)
        optimizer.init(model.trainable_parameters())

        # Flatten the state for serialization
        state = tree_flatten(optimizer.state)

        # Make a new optimizer and load the state
        optimizer = opt.Adam(learning_rate=3e-4)
        optimizer.state = tree_unflatten(state)

        # This should work without any errors
        grads = model.trainable_parameters()
        optimizer.update(model, grads)

    def test_multi_optimizer(self):
        """Parameters are routed to Adam (ndim > 1) or SGD (everything else)."""
        model = self._make_model()
        optimizer = opt.MultiOptimizer(
            [opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
            [lambda name, weight: weight.ndim > 1],
        )
        optimizer.init(model.trainable_parameters())

        self.assertEqual(len(optimizer.state["states"]), 2)

        adam_states = tree_flatten(optimizer.state["states"][0])
        sgd_states = tree_flatten(optimizer.state["states"][1])
        # Adam holds two flattened entries per parameter versus SGD's one; the
        # "- 2" presumably discounts per-optimizer scalar entries — TODO confirm.
        self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)
        self.assertFalse(any("bias" in k for k, _ in adam_states))
        self.assertFalse(any("weight" in k for k, _ in sgd_states))
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # When executed as a script, run the tests through mlx_tests.MLXTestRunner.
    mlx_tests.MLXTestRunner()
 |