From 08d51bf23216bdffbe82006ac2ec09312b8414d9 Mon Sep 17 00:00:00 2001 From: ShiJZ <100902397+JingzheShi@users.noreply.github.com> Date: Sat, 9 Dec 2023 13:39:08 +0800 Subject: [PATCH] Make it easier to test new optimizers implemented: no need to change test file manually (#90) * add helper function get_all_optimizers() in test_optimizers.py * remove unused import --- python/tests/test_optimizers.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index e136f6b3e..de9ab68b6 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -1,12 +1,22 @@ # Copyright © 2023 Apple Inc. import unittest +import inspect import mlx.core as mx import mlx.optimizers as opt import mlx.utils import mlx_tests +def get_all_optimizers(): + classes = dict() + for name, obj in inspect.getmembers(opt): + if inspect.isclass(obj): + if obj.__name__ not in ["OptimizerState", "Optimizer"]: + classes[name] = obj + return classes + +optimizers_dict = get_all_optimizers() class TestOptimizers(mlx_tests.MLXTestCase): def test_optimizers(self): @@ -16,7 +26,8 @@ class TestOptimizers(mlx_tests.MLXTestCase): } grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) - for optim in [opt.SGD(0.1), opt.Adam(0.1)]: + for optim_class in optimizers_dict.values(): + optim = optim_class(0.1) update = optim.apply_gradients(grads, params) mx.eval(update) equal_shape = mlx.utils.tree_map(