mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	add test
This commit is contained in:
		@@ -1134,6 +1134,30 @@ class TestCompile(mlx_tests.MLXTestCase):
 | 
			
		||||
        a = fun2(mx.array(-1.0))
 | 
			
		||||
        self.assertEqual(a.item(), 1.0)
 | 
			
		||||
 | 
			
		||||
    def test_multiple_compile_same_capture(self):
 | 
			
		||||
        def fun(do_compile):
 | 
			
		||||
            t = mx.ones((10,))
 | 
			
		||||
            u = (1.0 - t) * 0.0 + t * 3.0
 | 
			
		||||
 | 
			
		||||
            o = mx.ones((6,))
 | 
			
		||||
            b = o[:, None] * u
 | 
			
		||||
 | 
			
		||||
            c = b * mx.ones_like(u)
 | 
			
		||||
 | 
			
		||||
            a = mx.ones((6,))
 | 
			
		||||
            if do_compile:
 | 
			
		||||
                d = mx.compile(lambda x: x @ b)(a)
 | 
			
		||||
                e = mx.compile(lambda x: x @ c.T)(d)
 | 
			
		||||
            else:
 | 
			
		||||
                d = a @ b
 | 
			
		||||
                e = d @ c.T
 | 
			
		||||
            return e
 | 
			
		||||
 | 
			
		||||
        out = fun(True)
 | 
			
		||||
        mx.eval(out)
 | 
			
		||||
        expected = fun(False)
 | 
			
		||||
        self.assertTrue(mx.allclose(out, expected))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    mlx_tests.MLXTestRunner()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user