mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Make it easier to test new optimizers implemented: no need to change test file manually (#90)
* add helper function get_all_optimizers() in test_optimizers.py * remove unused import
This commit is contained in:
parent
cb9e585b8e
commit
08d51bf232
@ -1,12 +1,22 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import unittest
|
||||
import inspect
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.optimizers as opt
|
||||
import mlx.utils
|
||||
import mlx_tests
|
||||
|
||||
def get_all_optimizers():
|
||||
classes = dict()
|
||||
for name, obj in inspect.getmembers(opt):
|
||||
if inspect.isclass(obj):
|
||||
if obj.__name__ not in ["OptimizerState", "Optimizer"]:
|
||||
classes[name] = obj
|
||||
return classes
|
||||
|
||||
optimizers_dict = get_all_optimizers()
|
||||
|
||||
class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
def test_optimizers(self):
|
||||
@ -16,7 +26,8 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
}
|
||||
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
|
||||
for optim_class in optimizers_dict.values():
|
||||
optim = optim_class(0.1)
|
||||
update = optim.apply_gradients(grads, params)
|
||||
mx.eval(update)
|
||||
equal_shape = mlx.utils.tree_map(
|
||||
|
Loading…
Reference in New Issue
Block a user