Added Adafactor (#415)

* Added adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* Added underscore for internal functions and removed the plus sign in the last return statement

* Removed parameter rms from the optimizer state because it's not needed (RMS is recomputed from the current tensors each step; see the sketch below)

* Added simple MNIST test for Adafactor and temporary training log

* remove test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
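
Background for the "rms" change above: Adafactor (Shazeer & Stern, 2018) keeps factored second-moment statistics (one running mean per row and per column of each 2-D parameter) instead of a full element-wise accumulator, and the RMS quantities used for update clipping are recomputed from the current tensors on every step, so there is nothing to store. A minimal sketch of that factored estimate, with illustrative names that do not claim to match MLX's internals:

    import mlx.core as mx

    def _factored_update_sketch(grad, row_acc, col_acc, beta, eps=1e-30):
        # Running means of the squared gradient along each axis.
        sq = mx.square(grad) + eps
        row_acc = beta * row_acc + (1 - beta) * mx.mean(sq, axis=-1)
        col_acc = beta * col_acc + (1 - beta) * mx.mean(sq, axis=-2)
        # Rank-1 reconstruction of the per-element second-moment estimate.
        v = mx.expand_dims(row_acc / mx.mean(row_acc), -1) * mx.expand_dims(col_acc, -2)
        update = grad * mx.rsqrt(v)
        # RMS is recomputed here rather than kept in the optimizer state;
        # the max(1, rms) clip corresponds to the paper's threshold d = 1.
        rms = mx.sqrt(mx.mean(mx.square(update)))
        return row_acc, col_acc, update / mx.maximum(1.0, rms)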
Author: Hazem Essam
Date: 2024-01-24 01:11:27 +02:00
Committed by: GitHub
Parent: 755dcf6137
Commit: 37fc9db82c
3 changed files with 164 additions and 10 deletions


@@ -39,6 +39,24 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
         self.assertTrue(all_equal)
+
+    def test_adafactor(self):
+        x = mx.zeros((5, 5))
+        grad = mx.ones_like(x)
+        optimizer = opt.Adafactor()
+        for _ in range(2):
+            xp = optimizer.apply_single(grad, x, optimizer.state)
+        self.assertEqual(xp.dtype, x.dtype)
+        self.assertEqual(xp.shape, x.shape)
+
+        x = mx.zeros((5, 5), mx.float16)
+        grad = mx.ones_like(x)
+        optimizer = opt.Adafactor()
+        for _ in range(2):
+            xp = optimizer.apply_single(grad, x, optimizer.state)
+        self.assertEqual(xp.dtype, x.dtype)
+        self.assertEqual(xp.shape, x.shape)
+        self.assertEqual(optimizer.state["step"], 2)


 if __name__ == "__main__":
     unittest.main()
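
For anyone trying the new optimizer outside the test suite, here is a minimal usage sketch built on the same per-parameter API the test exercises, Adafactor.apply_single(gradient, parameter, state); the quadratic objective is purely illustrative:

    import mlx.core as mx
    import mlx.optimizers as opt

    optimizer = opt.Adafactor()
    x = mx.ones((5, 5))

    # Toy objective: drive all entries of x toward zero.
    grad_fn = mx.grad(lambda p: mx.sum(mx.square(p)))

    for _ in range(10):
        grad = grad_fn(x)
        x = optimizer.apply_single(grad, x, optimizer.state)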