in place ops behave in place, fix some overloads (#411)

This commit is contained in:
Awni Hannun
2024-01-09 16:05:38 -08:00
committed by GitHub
parent 961435a243
commit 1d90a76d63
2 changed files with 202 additions and 19 deletions

View File

@@ -1313,6 +1313,65 @@ class TestArray(mlx_tests.MLXTestCase):
rtol=0,
)
def test_logical_overloads(self):
with self.assertRaises(ValueError):
mx.array(1.0) & mx.array(1)
with self.assertRaises(ValueError):
mx.array(1.0) | mx.array(1)
self.assertEqual((mx.array(True) & True).item(), True)
self.assertEqual((mx.array(True) & False).item(), False)
self.assertEqual((mx.array(True) | False).item(), True)
self.assertEqual((mx.array(False) | False).item(), False)
self.assertEqual((~mx.array(False)).item(), True)
def test_inplace(self):
iops = [
"__iadd__",
"__isub__",
"__imul__",
"__ifloordiv__",
"__imod__",
"__ipow__",
]
for op in iops:
a = mx.array([1, 2, 3])
a_np = np.array(a)
b = a
b = getattr(a, op)(3)
self.assertTrue(mx.array_equal(a, b))
out_np = getattr(a_np, op)(3)
self.assertTrue(np.array_equal(out_np, a))
with self.assertRaises(ValueError):
a = mx.array([1])
a /= 1
a = mx.array([2.0])
b = a
b /= 2
self.assertEqual(b.item(), 1.0)
self.assertEqual(b.item(), a.item())
a = mx.array(True)
b = a
b &= False
self.assertEqual(b.item(), False)
self.assertEqual(b.item(), a.item())
a = mx.array(False)
b = a
b |= True
self.assertEqual(b.item(), True)
self.assertEqual(b.item(), a.item())
# In-place matmul on its own
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
b = a
b @= a
self.assertTrue(mx.array_equal(a, b))
if __name__ == "__main__":
unittest.main()