mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Fix flaky addmm tests (#2581)
This commit is contained in:
		@@ -702,7 +702,7 @@ class TestBlas(mlx_tests.MLXTestCase):
 | 
			
		||||
            b = mx.ones((5, 5))
 | 
			
		||||
            out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
 | 
			
		||||
            expected = beta * a + alpha * (b @ a)
 | 
			
		||||
            self.assertTrue(mx.allclose(expected, out))
 | 
			
		||||
            self.assertTrue(mx.allclose(expected, out, atol=1e-5))
 | 
			
		||||
 | 
			
		||||
            # Broadcast c
 | 
			
		||||
            a = mx.ones((5, 5))
 | 
			
		||||
@@ -710,7 +710,7 @@ class TestBlas(mlx_tests.MLXTestCase):
 | 
			
		||||
            c = mx.ones((1, 5))
 | 
			
		||||
            out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
 | 
			
		||||
            expected = beta * c + alpha * (a @ b)
 | 
			
		||||
            self.assertTrue(mx.allclose(expected, out))
 | 
			
		||||
            self.assertTrue(mx.allclose(expected, out, atol=1e-5))
 | 
			
		||||
 | 
			
		||||
    def test_addmm_grad(self):
 | 
			
		||||
        def make_ref_addmm(alpha, beta):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user