Make it easier to test new optimizers implemented: no need to change test file manually (#90)

* add helper function get_all_optimizers() in test_optimizers.py

* remove unused import
This commit is contained in:
ShiJZ 2023-12-09 13:39:08 +08:00 committed by GitHub
parent cb9e585b8e
commit 08d51bf232
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,12 +1,22 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import unittest import unittest
import inspect
import mlx.core as mx import mlx.core as mx
import mlx.optimizers as opt import mlx.optimizers as opt
import mlx.utils import mlx.utils
import mlx_tests import mlx_tests
def get_all_optimizers():
classes = dict()
for name, obj in inspect.getmembers(opt):
if inspect.isclass(obj):
if obj.__name__ not in ["OptimizerState", "Optimizer"]:
classes[name] = obj
return classes
optimizers_dict = get_all_optimizers()
class TestOptimizers(mlx_tests.MLXTestCase): class TestOptimizers(mlx_tests.MLXTestCase):
def test_optimizers(self): def test_optimizers(self):
@ -16,7 +26,8 @@ class TestOptimizers(mlx_tests.MLXTestCase):
} }
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
for optim in [opt.SGD(0.1), opt.Adam(0.1)]: for optim_class in optimizers_dict.values():
optim = optim_class(0.1)
update = optim.apply_gradients(grads, params) update = optim.apply_gradients(grads, params)
mx.eval(update) mx.eval(update)
equal_shape = mlx.utils.tree_map( equal_shape = mlx.utils.tree_map(