Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-01 12:49:44 +08:00
Added Adafactor (#415)
* Added adafactor
* Added Adafactor and ran pre-commit
* Modified operations
* Added docstrings
* Switched two ops to fix a bug
* Added underscore for internal functions and removed the plus sign in the last return statement
* Removed parameter rms from the optimizer state because it's not needed
* Added simple MNIST test for Adafactor and temporary training log
* Removed test files
* Nits in docs
* Comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
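As a usage note, here is a minimal sketch of driving the new optimizer through the apply_single API that the test below exercises. The import aliases mx and opt are assumptions matching the conventions of the MLX test module; everything else mirrors the calls shown in the diff:

import mlx.core as mx
import mlx.optimizers as opt

# A dummy parameter and a matching gradient.
x = mx.zeros((5, 5))
grad = mx.ones_like(x)

# Adafactor with default hyperparameters; its state tracks the step
# count and the factored second-moment accumulators.
optimizer = opt.Adafactor()

# apply_single returns the updated parameter and updates the state.
for _ in range(2):
    x = optimizer.apply_single(grad, x, optimizer.state)

print(optimizer.state["step"])  # 2 after two updates, per the test below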
@@ -39,6 +39,24 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
         self.assertTrue(all_equal)
 
+    def test_adafactor(self):
+        x = mx.zeros((5, 5))
+        grad = mx.ones_like(x)
+        optimizer = opt.Adafactor()
+        for _ in range(2):
+            xp = optimizer.apply_single(grad, x, optimizer.state)
+        self.assertEqual(xp.dtype, x.dtype)
+        self.assertEqual(xp.shape, x.shape)
+
+        x = mx.zeros((5, 5), mx.float16)
+        grad = mx.ones_like(x)
+        optimizer = opt.Adafactor()
+        for _ in range(2):
+            xp = optimizer.apply_single(grad, x, optimizer.state)
+        self.assertEqual(xp.dtype, x.dtype)
+        self.assertEqual(xp.shape, x.shape)
+        self.assertEqual(optimizer.state["step"], 2)
+
 
 if __name__ == "__main__":
     unittest.main()
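For background on why the commit can drop the per-parameter rms entry from the optimizer state: Adafactor (Shazeer & Stern, 2018) stores only row and column statistics of the squared gradient rather than a full elementwise second moment, i.e. O(n + m) state for an n x m weight matrix. The following is a rough NumPy sketch of that factored estimate, not MLX's implementation; variable names are ours, and the paper's beta scheduling, update clipping, and relative step size are omitted:

import numpy as np

def factored_second_moment(grad, row_acc, col_acc, beta=0.999, eps=1e-30):
    # Exponential moving averages of the row and column means of grad**2;
    # eps keeps the accumulators strictly positive.
    sq = grad * grad + eps
    row_acc = beta * row_acc + (1 - beta) * sq.mean(axis=1)
    col_acc = beta * col_acc + (1 - beta) * sq.mean(axis=0)
    # Rank-1 reconstruction of the n x m second-moment matrix; it is
    # exact whenever grad**2 is itself rank one.
    v_hat = np.outer(row_acc, col_acc) / row_acc.mean()
    # Gradient scaled by the approximate root mean square.
    return row_acc, col_acc, grad / np.sqrt(v_hat)

Compared with Adam's elementwise second moment, this keeps only n + m accumulator values per matrix, which is the memory saving Adafactor was designed for.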