mirror of https://github.com/ml-explore/mlx.git
	Layer norm grad fix donation bug (#941)
* add layer norm grad test
* Fix donation bug in layernorm vjp

Co-authored-by: Awni Hannun <awni@apple.com>
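For context, the code path this commit touches is hit whenever mx.fast.layer_norm appears inside a differentiated function, since computing gradients invokes the kernel's VJP. A minimal usage sketch of that gradient path (the names and shapes below are illustrative, not taken from the commit):

import mlx.core as mx


def loss(params, x):
    w, b = params
    # Differentiating through the fused kernel calls its VJP, the
    # routine whose donation bug this commit fixes.
    return mx.fast.layer_norm(x, w, b, 1e-5).sum()


params = (mx.ones((8,)), mx.zeros((8,)))
x = mx.random.normal(shape=(2, 2, 8))
gw, gb = mx.grad(loss)(params, x)  # gradients w.r.t. weight and bias
mx.eval(gw, gb)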
commit 110d9b149d
parent 9cbff5ec1d
committed by GitHub
@@ -375,6 +375,21 @@ class TestFast(mlx_tests.MLXTestCase):
         self.assertLess(mx.abs(gb1).max(), 1e-9)
         self.assertLess(mx.abs(gb2).max(), 1e-9)
 
+    def test_layer_norm_grad_params(self):
+        eps = 1e-5
+        f1 = lambda params, x: (layer_norm(x, params[0], params[1], eps)).sum()
+        f2 = lambda params, x: (mx.fast.layer_norm(x, params[0], params[1], eps)).sum()
+
+        w = mx.ones((8,))
+        b = mx.zeros((8,))
+        x = mx.random.normal(shape=(2, 2, 8))
+        mx.eval(x, w, b)
+
+        gw1, gb1 = mx.grad(f1)((w, b), x)
+        gw2, gb2 = mx.grad(f2)((w, b), x)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
+
     def test_fast_transforms(self):
         x = mx.random.uniform(shape=(2, 2, 8))
 
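The layer_norm called in f1 is a plain-MLX reference implementation defined elsewhere in the test module; it is not part of this hunk. As a rough sketch only, assuming the usual formulation (normalize over the last axis in float32, then apply the affine parameters), such a reference could look like:

import mlx.core as mx


def layer_norm(x, weight, bias, eps):
    # Reference used only to check the fused kernel: compute mean and
    # variance over the last axis in float32, then apply the optional
    # weight and bias.
    orig_dtype = x.dtype
    x = x.astype(mx.float32)
    mu = mx.mean(x, axis=-1, keepdims=True)
    var = mx.var(x, axis=-1, keepdims=True)
    y = (x - mu) * mx.rsqrt(var + eps)
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y.astype(orig_dtype)

With a reference like this, the assertions bound the relative error between the two weight gradients (and likewise the bias gradients) by 1e-5, which is the quantity the new test guards now that the donation bug in the fused VJP is fixed.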