From b7588fd5d7fcd7181297eb8213bdc8e181d21edf Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 7 Mar 2024 09:34:11 -0800 Subject: [PATCH] fix inplace to not make a shallow copy (#804) --- python/src/array.cpp | 20 ++++++++++---------- python/tests/test_array.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 838970d84..dd8ad89e5 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -802,7 +802,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__iadd__", - [](array& a, const ScalarOrArray v) { + [](array& a, const ScalarOrArray v) -> array& { a.overwrite_descriptor(add(a, to_array(v, a.dtype()))); return a; }, @@ -821,7 +821,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__isub__", - [](array& a, const ScalarOrArray v) { + [](array& a, const ScalarOrArray v) -> array& { a.overwrite_descriptor(subtract(a, to_array(v, a.dtype()))); return a; }, @@ -840,7 +840,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__imul__", - [](array& a, const ScalarOrArray v) { + [](array& a, const ScalarOrArray v) -> array& { a.overwrite_descriptor(multiply(a, to_array(v, a.dtype()))); return a; }, @@ -859,7 +859,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__itruediv__", - [](array& a, const ScalarOrArray v) { + [](array& a, const ScalarOrArray v) -> array& { if (!is_floating_point(a.dtype())) { throw std::invalid_argument( "In place division cannot cast to non-floating point type."); @@ -894,7 +894,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__ifloordiv__", - [](array& a, const ScalarOrArray v) { + [](array& a, const ScalarOrArray v) -> array& { a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype()))); return a; }, @@ -914,7 +914,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__imod__", - [](array& a, const ScalarOrArray v) { + [](array& a, const ScalarOrArray v) -> array& { a.overwrite_descriptor(remainder(a, to_array(v, a.dtype()))); return a; }, @@ -980,7 +980,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__imatmul__", - [](array& a, array& other) { + [](array& a, array& other) -> array& { a.overwrite_descriptor(matmul(a, other)); return a; }, @@ -999,7 +999,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__ipow__", - [](array& a, const ScalarOrArray v) { + [](array& a, const ScalarOrArray v) -> array& { a.overwrite_descriptor(power(a, to_array(v, a.dtype()))); return a; }, @@ -1034,7 +1034,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__iand__", - [](array& a, const ScalarOrArray v) { + [](array& a, const ScalarOrArray v) -> array& { auto b = to_array(v, a.dtype()); if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { throw std::invalid_argument( @@ -1065,7 +1065,7 @@ void init_array(py::module_& m) { "other"_a) .def( "__ior__", - [](array& a, const ScalarOrArray v) { + [](array& a, const ScalarOrArray v) -> array& { auto b = to_array(v, a.dtype()); if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { throw std::invalid_argument( diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 4fbb2d0ae..07c7bd18d 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1442,6 +1442,21 @@ class TestArray(mlx_tests.MLXTestCase): b @= a self.assertTrue(mx.array_equal(a, b)) + def test_inplace_preserves_ids(self): + a = mx.array([1.0]) + orig_id = id(a) + a += mx.array(2.0) + self.assertEqual(id(a), orig_id) + + a[0] = 2.0 + self.assertEqual(id(a), orig_id) + + a -= mx.array(3.0) + self.assertEqual(id(a), orig_id) + + a *= mx.array(3.0) + self.assertEqual(id(a), orig_id) + def test_load_from_pickled_np(self): a = np.array([1, 2, 3], dtype=np.int32) b = pickle.loads(pickle.dumps(a))