fix inplace to not make a shallow copy (#804)

This commit is contained in:
Awni Hannun 2024-03-07 09:34:11 -08:00 committed by GitHub
parent f512b905c7
commit b7588fd5d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 10 deletions

View File

@ -802,7 +802,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__iadd__", "__iadd__",
[](array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(add(a, to_array(v, a.dtype()))); a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
return a; return a;
}, },
@ -821,7 +821,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__isub__", "__isub__",
[](array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype()))); a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
return a; return a;
}, },
@ -840,7 +840,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__imul__", "__imul__",
[](array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype()))); a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
return a; return a;
}, },
@ -859,7 +859,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__itruediv__", "__itruediv__",
[](array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) -> array& {
if (!is_floating_point(a.dtype())) { if (!is_floating_point(a.dtype())) {
throw std::invalid_argument( throw std::invalid_argument(
"In place division cannot cast to non-floating point type."); "In place division cannot cast to non-floating point type.");
@ -894,7 +894,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__ifloordiv__", "__ifloordiv__",
[](array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype()))); a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
return a; return a;
}, },
@ -914,7 +914,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__imod__", "__imod__",
[](array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype()))); a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
return a; return a;
}, },
@ -980,7 +980,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__imatmul__", "__imatmul__",
[](array& a, array& other) { [](array& a, array& other) -> array& {
a.overwrite_descriptor(matmul(a, other)); a.overwrite_descriptor(matmul(a, other));
return a; return a;
}, },
@ -999,7 +999,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__ipow__", "__ipow__",
[](array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(power(a, to_array(v, a.dtype()))); a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
return a; return a;
}, },
@ -1034,7 +1034,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__iand__", "__iand__",
[](array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) -> array& {
auto b = to_array(v, a.dtype()); auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument( throw std::invalid_argument(
@ -1065,7 +1065,7 @@ void init_array(py::module_& m) {
"other"_a) "other"_a)
.def( .def(
"__ior__", "__ior__",
[](array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) -> array& {
auto b = to_array(v, a.dtype()); auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -1442,6 +1442,21 @@ class TestArray(mlx_tests.MLXTestCase):
b @= a b @= a
self.assertTrue(mx.array_equal(a, b)) self.assertTrue(mx.array_equal(a, b))
def test_inplace_preserves_ids(self):
a = mx.array([1.0])
orig_id = id(a)
a += mx.array(2.0)
self.assertEqual(id(a), orig_id)
a[0] = 2.0
self.assertEqual(id(a), orig_id)
a -= mx.array(3.0)
self.assertEqual(id(a), orig_id)
a *= mx.array(3.0)
self.assertEqual(id(a), orig_id)
def test_load_from_pickled_np(self): def test_load_from_pickled_np(self):
a = np.array([1, 2, 3], dtype=np.int32) a = np.array([1, 2, 3], dtype=np.int32)
b = pickle.loads(pickle.dumps(a)) b = pickle.loads(pickle.dumps(a))