mirror of https://github.com/ml-explore/mlx.git
synced 2025-10-31 16:21:27 +08:00

Add a multi optimizer (#1916)

Authored by Angelos Katharopoulos, committed by GitHub.
parent a0737273d3
commit 9680f72cca
@@ -39,6 +39,7 @@ def tree_equal(fn, *args):
 
 
 optimizers_dict = get_all_optimizers()
+del optimizers_dict["MultiOptimizer"]
 
 
 class TestOptimizers(mlx_tests.MLXTestCase):
@@ -500,6 +501,30 @@ class TestSchedulers(unittest.TestCase):
         grads = model.trainable_parameters()
         optimizer.update(model, grads)
 
+    def test_multi_optimizer(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = nn.Linear(2, 2)
+                self.drop = nn.Dropout(p=0.5)
+                self.l2 = nn.Linear(2, 2)
+                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
+
+        model = Model()
+        optimizer = opt.MultiOptimizer(
+            [opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
+            [lambda name, weight: weight.ndim > 1],
+        )
+        optimizer.init(model.trainable_parameters())
+
+        self.assertEqual(len(optimizer.state["states"]), 2)
+
+        adam_states = tree_flatten(optimizer.state["states"][0])
+        sgd_states = tree_flatten(optimizer.state["states"][1])
+        self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)
+        self.assertFalse(any("bias" in k for k, v in adam_states))
+        self.assertFalse(any("weight" in k for k, v in sgd_states))
+
 
 if __name__ == "__main__":
     unittest.main()
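The test above pins down the routing rule: each filter in the second argument pairs with the optimizer at the same index, and the last optimizer acts as the catch-all for parameters no filter claims (weights with ndim > 1 land in Adam's state, the 1-D biases in SGD's). Below is a minimal usage sketch built on that reading; the model, data, and training step are illustrative, not part of this commit:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Adam for the 2-D weight matrices; SGD, as the last (assumed catch-all)
# entry, picks up everything else, i.e. the 1-D biases.
optimizer = opt.MultiOptimizer(
    [opt.Adam(learning_rate=1e-3), opt.SGD(learning_rate=0.1)],
    [lambda name, weight: weight.ndim > 1],
)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.uniform(shape=(16, 4))
y = mx.random.uniform(shape=(16, 2))

# The usual MLX training step; MultiOptimizer is used as a drop-in Optimizer.
loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)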
@@ -3,6 +3,7 @@
 import unittest
 
 import mlx.core as mx
+import mlx.nn as nn
 import mlx.utils
 import mlx_tests
 
@@ -22,6 +23,29 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
         self.assertEqual(list(zip(*flat_tree))[1], vals)
         self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
 
+    def test_merge(self):
+        t1 = {"a": 0}
+        t2 = {"b": 1}
+        t = mlx.utils.tree_merge(t1, t2)
+        self.assertEqual({"a": 0, "b": 1}, t)
+        with self.assertRaises(ValueError):
+            mlx.utils.tree_merge(t1, t1)
+        with self.assertRaises(ValueError):
+            mlx.utils.tree_merge(t, t1)
+
+        mod1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
+        mod2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
+        mod = nn.Sequential(mod1, mod2)
+
+        params1 = {"layers": [mod1.parameters()]}
+        params2 = {"layers": [None, mod2.parameters()]}
+        params = mlx.utils.tree_merge(params1, params2)
+        for (k1, v1), (k2, v2) in zip(
+            mlx.utils.tree_flatten(params), mlx.utils.tree_flatten(mod.parameters())
+        ):
+            self.assertEqual(k1, k2)
+            self.assertTrue(mx.array_equal(v1, v2))
+
 
 if __name__ == "__main__":
     unittest.main()
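From the assertions above, tree_merge combines two trees with non-overlapping leaves, raises ValueError whenever a leaf appears in both inputs (even with identical values, as the tree_merge(t1, t1) case shows), and treats None entries in lists as holes the other tree may fill, with shorter lists padded out. This is how per-optimizer parameter subsets recombine into one tree. A small standalone sketch of exactly those cases; the concrete trees are illustrative:

import mlx.utils

# Disjoint leaves merge into one tree.
t = mlx.utils.tree_merge({"a": 0}, {"b": 1})
assert t == {"a": 0, "b": 1}

# Any leaf present in both inputs is a conflict, even with equal values.
try:
    mlx.utils.tree_merge({"a": 0}, {"a": 0})
    assert False, "expected a ValueError"
except ValueError:
    pass

# None entries in a list are holes the other tree fills; the shorter
# list on the left is padded to match (mirrors the test above).
left = {"layers": [{"w": 0}]}
right = {"layers": [None, {"w": 1}]}
assert mlx.utils.tree_merge(left, right) == {"layers": [{"w": 0}, {"w": 1}]}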
@@ -345,6 +345,7 @@ class TestVmap(mlx_tests.MLXTestCase):
             )
 
     def test_vmap_inverse(self):
+        mx.random.seed(42)
         a = mx.random.uniform(shape=(3, 4, 4))
 
         cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)
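This last hunk only seeds the global RNG so the random matrices fed to mx.linalg.inv are identical on every run. A quick check of that seed-replay behavior, assuming nothing beyond mx.random's global seed semantics:

import mlx.core as mx

mx.random.seed(42)
a = mx.random.uniform(shape=(3, 4, 4))
mx.random.seed(42)
b = mx.random.uniform(shape=(3, 4, 4))
assert mx.array_equal(a, b)  # same seed, same stream of draws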