Mirror of https://github.com/ml-explore/mlx.git
	Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety
* speed up no cache case
* fix opt test
* optimizer docs
* fix adafactor
@@ -299,16 +299,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):

 class TestSchedulers(unittest.TestCase):
     def test_decay_lr(self):
         for optim_class in optimizers_dict.values():
-            lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+            lr_schedule = opt.step_decay(1e-1, 0.9, 1)
             optimizer = optim_class(learning_rate=lr_schedule)

             params = {"w": mx.ones((5, 5))}
             grads = tree_map(lambda x: mx.ones_like(x), params)

             for it in range(10):
                 optimizer.apply_gradients(grads, params)
                 expected_lr = 0.1 * (0.9**it)
-                return optimizer.apply_gradients(grads, params)
+                self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)

     def test_step_decay(self):
         lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
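
For context, a minimal self-contained sketch of the scheduler usage the test above exercises. It is not part of the commit; the choice of SGD and the printout are illustrative. The closed-form 0.1 * 0.9**it follows the test's own expectation that step_decay(init, decay_rate, step_size) decays the rate by decay_rate every step_size optimizer steps, so with step_size=1 the rate changes on every update.

# Sketch only: mirrors the fixed test, not part of commit #777.
import mlx.core as mx
import mlx.optimizers as opt
from mlx.utils import tree_map

# Decay by 0.9 every optimizer step, starting from 0.1 (step_size=1, as in the fixed test).
lr_schedule = opt.step_decay(1e-1, 0.9, 1)
optimizer = opt.SGD(learning_rate=lr_schedule)  # SGD is an arbitrary choice here

params = {"w": mx.ones((5, 5))}
grads = tree_map(lambda x: mx.ones_like(x), params)

for it in range(10):
    # apply_gradients returns the updated parameter tree and advances the schedule.
    params = optimizer.apply_gradients(grads, params)
    # Current learning rate as tracked by the optimizer vs. the expected closed form.
    print(it, optimizer.learning_rate.item(), 0.1 * 0.9**it)

Passing a callable as learning_rate lets the optimizer advance the schedule on each apply_gradients call, which is why the fixed test uses a step size of 1: the expected value then changes every iteration of the loop.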