mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
in place ops behave in place, fix some overloads (#411)
This commit is contained in:
@@ -1313,6 +1313,65 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
rtol=0,
|
||||
)
|
||||
|
||||
def test_logical_overloads(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.array(1.0) & mx.array(1)
|
||||
with self.assertRaises(ValueError):
|
||||
mx.array(1.0) | mx.array(1)
|
||||
|
||||
self.assertEqual((mx.array(True) & True).item(), True)
|
||||
self.assertEqual((mx.array(True) & False).item(), False)
|
||||
self.assertEqual((mx.array(True) | False).item(), True)
|
||||
self.assertEqual((mx.array(False) | False).item(), False)
|
||||
self.assertEqual((~mx.array(False)).item(), True)
|
||||
|
||||
def test_inplace(self):
|
||||
iops = [
|
||||
"__iadd__",
|
||||
"__isub__",
|
||||
"__imul__",
|
||||
"__ifloordiv__",
|
||||
"__imod__",
|
||||
"__ipow__",
|
||||
]
|
||||
|
||||
for op in iops:
|
||||
a = mx.array([1, 2, 3])
|
||||
a_np = np.array(a)
|
||||
b = a
|
||||
b = getattr(a, op)(3)
|
||||
self.assertTrue(mx.array_equal(a, b))
|
||||
out_np = getattr(a_np, op)(3)
|
||||
self.assertTrue(np.array_equal(out_np, a))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.array([1])
|
||||
a /= 1
|
||||
|
||||
a = mx.array([2.0])
|
||||
b = a
|
||||
b /= 2
|
||||
self.assertEqual(b.item(), 1.0)
|
||||
self.assertEqual(b.item(), a.item())
|
||||
|
||||
a = mx.array(True)
|
||||
b = a
|
||||
b &= False
|
||||
self.assertEqual(b.item(), False)
|
||||
self.assertEqual(b.item(), a.item())
|
||||
|
||||
a = mx.array(False)
|
||||
b = a
|
||||
b |= True
|
||||
self.assertEqual(b.item(), True)
|
||||
self.assertEqual(b.item(), a.item())
|
||||
|
||||
# In-place matmul on its own
|
||||
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
|
||||
b = a
|
||||
b @= a
|
||||
self.assertTrue(mx.array_equal(a, b))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user