mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety * speed up no cache case * fix opt test * optimizer docs * otpimizer docs * fix adafactor * fix adafactor
This commit is contained in:
@@ -299,16 +299,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
class TestSchedulers(unittest.TestCase):
|
||||
def test_decay_lr(self):
|
||||
for optim_class in optimizers_dict.values():
|
||||
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
|
||||
lr_schedule = opt.step_decay(1e-1, 0.9, 1)
|
||||
optimizer = optim_class(learning_rate=lr_schedule)
|
||||
|
||||
params = {"w": mx.ones((5, 5))}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
for it in range(10):
|
||||
optimizer.apply_gradients(grads, params)
|
||||
expected_lr = 0.1 * (0.9**it)
|
||||
self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
|
||||
return optimizer.apply_gradients(grads, params)
|
||||
|
||||
def test_step_decay(self):
|
||||
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
|
||||
|
Reference in New Issue
Block a user