mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	in place ops behave in place, fix some overloads (#411)
This commit is contained in:
		| @@ -1313,6 +1313,65 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|                     rtol=0, | ||||
|                 ) | ||||
|  | ||||
|     def test_logical_overloads(self): | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.array(1.0) & mx.array(1) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.array(1.0) | mx.array(1) | ||||
|  | ||||
|         self.assertEqual((mx.array(True) & True).item(), True) | ||||
|         self.assertEqual((mx.array(True) & False).item(), False) | ||||
|         self.assertEqual((mx.array(True) | False).item(), True) | ||||
|         self.assertEqual((mx.array(False) | False).item(), False) | ||||
|         self.assertEqual((~mx.array(False)).item(), True) | ||||
|  | ||||
|     def test_inplace(self): | ||||
|         iops = [ | ||||
|             "__iadd__", | ||||
|             "__isub__", | ||||
|             "__imul__", | ||||
|             "__ifloordiv__", | ||||
|             "__imod__", | ||||
|             "__ipow__", | ||||
|         ] | ||||
|  | ||||
|         for op in iops: | ||||
|             a = mx.array([1, 2, 3]) | ||||
|             a_np = np.array(a) | ||||
|             b = a | ||||
|             b = getattr(a, op)(3) | ||||
|             self.assertTrue(mx.array_equal(a, b)) | ||||
|             out_np = getattr(a_np, op)(3) | ||||
|             self.assertTrue(np.array_equal(out_np, a)) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             a = mx.array([1]) | ||||
|             a /= 1 | ||||
|  | ||||
|         a = mx.array([2.0]) | ||||
|         b = a | ||||
|         b /= 2 | ||||
|         self.assertEqual(b.item(), 1.0) | ||||
|         self.assertEqual(b.item(), a.item()) | ||||
|  | ||||
|         a = mx.array(True) | ||||
|         b = a | ||||
|         b &= False | ||||
|         self.assertEqual(b.item(), False) | ||||
|         self.assertEqual(b.item(), a.item()) | ||||
|  | ||||
|         a = mx.array(False) | ||||
|         b = a | ||||
|         b |= True | ||||
|         self.assertEqual(b.item(), True) | ||||
|         self.assertEqual(b.item(), a.item()) | ||||
|  | ||||
|         # In-place matmul on its own | ||||
|         a = mx.array([[1.0, 2.0], [3.0, 4.0]]) | ||||
|         b = a | ||||
|         b @= a | ||||
|         self.assertTrue(mx.array_equal(a, b)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun