mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix donation condition for compilation (#1237)
This commit is contained in:
		| @@ -707,6 +707,18 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16) | ||||
|         self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x))) | ||||
|  | ||||
|     def test_max_into_equal(self): | ||||
|         x = mx.random.uniform(shape=(1, 2, 2)) | ||||
|         mx.eval(x) | ||||
|  | ||||
|         def fn(): | ||||
|             maxes = mx.max(x, axis=(1, 2), keepdims=True) | ||||
|             return x == maxes | ||||
|  | ||||
|         out = mx.compile(fn)() | ||||
|         expected = fn() | ||||
|         self.assertTrue(mx.array_equal(expected, out)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun