Make it easier to test new optimizers implemented: no need to change test file manually (#90)

* add helper function get_all_optimizers() in test_optimizers.py

* remove unused import
This commit is contained in:
ShiJZ 2023-12-09 13:39:08 +08:00 committed by GitHub
parent cb9e585b8e
commit 08d51bf232
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,12 +1,22 @@
# Copyright © 2023 Apple Inc.
import unittest
import inspect
import mlx.core as mx
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
def get_all_optimizers():
classes = dict()
for name, obj in inspect.getmembers(opt):
if inspect.isclass(obj):
if obj.__name__ not in ["OptimizerState", "Optimizer"]:
classes[name] = obj
return classes
optimizers_dict = get_all_optimizers()
class TestOptimizers(mlx_tests.MLXTestCase):
def test_optimizers(self):
@ -16,7 +26,8 @@ class TestOptimizers(mlx_tests.MLXTestCase):
}
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
for optim_class in optimizers_dict.values():
optim = optim_class(0.1)
update = optim.apply_gradients(grads, params)
mx.eval(update)
equal_shape = mlx.utils.tree_map(