# Copyright © 2023 Apple Inc.

import inspect
import unittest

import mlx.core as mx
import mlx.optimizers as opt
import mlx.utils
import mlx_tests


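# Collect every optimizer class exposed by mlx.optimizers, skipping the
# OptimizerState container and the Optimizer base class.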
def get_all_optimizers():
    classes = dict()
    for name, obj in inspect.getmembers(opt):
        if inspect.isclass(obj):
            if obj.__name__ not in ["OptimizerState", "Optimizer"]:
                classes[name] = obj
    return classes


optimizers_dict = get_all_optimizers()


class TestOptimizers(mlx_tests.MLXTestCase):
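    # Generic smoke test: every optimizer discovered above should accept a
    # parameter tree and produce an update with matching structure.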
    def test_optimizers(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)

        for optim_class in optimizers_dict.values():
            optim = optim_class(0.1)
            update = optim.apply_gradients(grads, params)
            mx.eval(update)
            # The update tree must mirror the parameter tree: same nesting and
            # the same shape at every leaf.
            equal_shape = mlx.utils.tree_map(
                lambda x, y: x.shape == y.shape, params, update
            )
            all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
            self.assertTrue(all_equal)

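    # Adafactor-specific checks: updates should preserve the input's shape and
    # dtype, and the optimizer should track how many steps it has taken.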
    def test_adafactor(self):
        x = mx.zeros((5, 5))
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)
            self.assertEqual(xp.dtype, x.dtype)
            self.assertEqual(xp.shape, x.shape)

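        # Repeat in half precision: the returned updates should keep the
        # parameter dtype rather than being promoted to float32.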
        x = mx.zeros((5, 5), mx.float16)
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)
            self.assertEqual(xp.dtype, x.dtype)
            self.assertEqual(xp.shape, x.shape)
        # apply_single is called twice above, so the optimizer's step counter
        # should read exactly 2.
        self.assertEqual(optimizer.state["step"], 2)


if __name__ == "__main__":
    unittest.main()