fix inplace to not make a shallow copy (#804)

This commit is contained in:
Awni Hannun
2024-03-07 09:34:11 -08:00
committed by GitHub
parent f512b905c7
commit b7588fd5d7
2 changed files with 25 additions and 10 deletions

View File

@@ -1442,6 +1442,21 @@ class TestArray(mlx_tests.MLXTestCase):
b @= a
self.assertTrue(mx.array_equal(a, b))
def test_inplace_preserves_ids(self):
a = mx.array([1.0])
orig_id = id(a)
a += mx.array(2.0)
self.assertEqual(id(a), orig_id)
a[0] = 2.0
self.assertEqual(id(a), orig_id)
a -= mx.array(3.0)
self.assertEqual(id(a), orig_id)
a *= mx.array(3.0)
self.assertEqual(id(a), orig_id)
def test_load_from_pickled_np(self):
a = np.array([1, 2, 3], dtype=np.int32)
b = pickle.loads(pickle.dumps(a))