mlx/python/tests/test_optimizers.py
Hazem Essam 37fc9db82c
Added Adafactor (#415)
* Added adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* added underscore for internal functions and removed the plus sign in the last return statement

* Removed parameter rms from the optimizer state because it's not needed

* Added simple MNIST test for Adafactor and temporary training log

* remove test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:11:27 -08:00

63 lines
1.8 KiB
Python

# Copyright © 2023 Apple Inc.
import inspect
import unittest
import mlx.core as mx
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
def get_all_optimizers(module=None):
    """Collect every concrete optimizer class exported by an optimizers module.

    Args:
        module: Module (or any object) to scan for optimizer classes.
            Defaults to ``mlx.optimizers`` (imported in this file as ``opt``).

    Returns:
        dict mapping each exported class name to the class, for every class
        that subclasses ``module.Optimizer`` other than ``Optimizer`` itself.
        Entries are ordered by name, as returned by ``inspect.getmembers``.
    """
    if module is None:
        module = opt
    base = module.Optimizer
    # Select by subclassing instead of a hard-coded list of excluded names.
    # Any non-optimizer class the module exports is then skipped
    # automatically, rather than being instantiated as an optimizer by the
    # tests. This covers a state container such as ``OptimizerState``,
    # assuming it does not subclass ``Optimizer``.
    return {
        name: obj
        for name, obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(obj, base) and obj is not base
    }
# Name -> class registry of every optimizer exercised by ``test_optimizers``.
optimizers_dict = get_all_optimizers()


class TestOptimizers(mlx_tests.MLXTestCase):
    """Smoke tests for the optimizers exported by ``mlx.optimizers``."""

    def test_optimizers(self):
        """Each optimizer returns an update tree whose leaf shapes match the params."""
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = mlx.utils.tree_map(mx.ones_like, params)

        for name, optim_class in optimizers_dict.items():
            # subTest names the failing optimizer in the report instead of a
            # bare "False is not true", and keeps checking the others.
            with self.subTest(optimizer=name):
                # 0.1 is passed as the first positional constructor argument
                # (presumably the learning rate for every optimizer).
                optim = optim_class(0.1)
                update = optim.apply_gradients(grads, params)
                mx.eval(update)
                equal_shape = mlx.utils.tree_map(
                    lambda x, y: x.shape == y.shape, params, update
                )
                all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
                self.assertTrue(all_equal)

    def test_adafactor(self):
        """Adafactor keeps dtype/shape and counts its steps, for float32 and float16."""
        for dtype in (mx.float32, mx.float16):
            with self.subTest(dtype=dtype):
                x = mx.zeros((5, 5), dtype)
                grad = mx.ones_like(x)
                optimizer = opt.Adafactor()
                for _ in range(2):
                    xp = optimizer.apply_single(grad, x, optimizer.state)
                    # Check every step, not only the last one.
                    self.assertEqual(xp.dtype, x.dtype)
                    self.assertEqual(xp.shape, x.shape)
                # Two apply_single calls on the shared state -> step == 2.
                self.assertEqual(optimizer.state["step"], 2)
if __name__ == "__main__":
    # Run the test cases when this file is executed directly as a script.
    unittest.main()