Make it easier to test newly implemented optimizers: no need to change the test file manually (#90)
* add helper function get_all_optimizers() in test_optimizers.py
* remove unused import
@@ -1,12 +1,22 @@
# Copyright © 2023 Apple Inc.

import unittest
import inspect

import mlx.core as mx
import mlx.optimizers as opt
import mlx.utils
import mlx_tests

def get_all_optimizers():
    classes = dict()
    for name, obj in inspect.getmembers(opt):
        if inspect.isclass(obj):
            if obj.__name__ not in ["OptimizerState", "Optimizer"]:
                classes[name] = obj
    return classes

optimizers_dict = get_all_optimizers()

class TestOptimizers(mlx_tests.MLXTestCase):
    def test_optimizers(self):
@@ -16,7 +26,8 @@ class TestOptimizers(mlx_tests.MLXTestCase):
        }
        grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)

        for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
        for optim_class in optimizers_dict.values():
            optim = optim_class(0.1)
            update = optim.apply_gradients(grads, params)
            mx.eval(update)
            equal_shape = mlx.utils.tree_map(
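For context, here is a minimal standalone sketch (not part of this commit) of the discovery-and-update pattern the test now relies on. The get_all_optimizers helper, apply_gradients, mx.eval, and tree_map usage come straight from the diff above; the parameter tree, shapes, and the final print are illustrative assumptions.

# Sketch of the pattern introduced by this commit: discover every optimizer
# class exposed by mlx.optimizers via inspect, then run one update step with each.
import inspect

import mlx.core as mx
import mlx.optimizers as opt
import mlx.utils


def get_all_optimizers():
    # Collect all optimizer classes, skipping the base/infrastructure classes.
    classes = dict()
    for name, obj in inspect.getmembers(opt):
        if inspect.isclass(obj):
            if obj.__name__ not in ["OptimizerState", "Optimizer"]:
                classes[name] = obj
    return classes


if __name__ == "__main__":
    # Illustrative parameter tree; the real test defines its own shapes.
    params = {"w": mx.zeros((4, 4)), "b": mx.zeros((4,))}
    grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)

    for name, optim_class in get_all_optimizers().items():
        # Like the test, this assumes every optimizer accepts a learning rate
        # as its first positional argument.
        optim = optim_class(0.1)
        update = optim.apply_gradients(grads, params)
        mx.eval(update)
        print(name, "updated w shape:", update["w"].shape)

Because the test iterates over whatever get_all_optimizers() returns, a newly added optimizer class in mlx.optimizers is picked up automatically, with no manual edit to the test file.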
Author: ShiJZ