Some fixes in cache / thread safety (#777)

* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* optimizer docs

* fix adafactor

* fix adafactor
This commit is contained in:
Awni Hannun
2024-03-05 13:30:50 -08:00
committed by GitHub
parent 859ae15a54
commit cbcf44a4ca
4 changed files with 60 additions and 41 deletions

View File

@@ -299,16 +299,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
class TestSchedulers(unittest.TestCase):
def test_decay_lr(self):
    """Check that a scheduled learning rate decays on every optimizer update.

    For each optimizer class in ``optimizers_dict`` the optimizer is built
    with ``opt.step_decay(1e-1, 0.9, 1)`` as its learning rate. Gradients
    are then applied ten times, and after each update
    ``optimizer.learning_rate`` must equal ``0.1 * 0.9**it`` within ``1e-7``.
    """
    for optim_class in optimizers_dict.values():
        # Single schedule definition; the earlier duplicate assignment with
        # 1000 as the last argument was dead code and has been removed.
        # NOTE(review): the last argument is presumably the step size, so 1
        # lets the rate change on every update -- confirm against
        # opt.step_decay.
        lr_schedule = opt.step_decay(1e-1, 0.9, 1)
        optimizer = optim_class(learning_rate=lr_schedule)

        params = {"w": mx.ones((5, 5))}
        grads = tree_map(lambda x: mx.ones_like(x), params)

        for it in range(10):
            # Apply the update first, then compare against the rate expected
            # for this iteration index.
            optimizer.apply_gradients(grads, params)
            expected_lr = 0.1 * (0.9**it)
            self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
        # No `return` here: returning from inside the loops would stop the
        # test after the first check and make the test method return a
        # non-None value.
def test_step_decay(self):
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)