mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
fix inplace to not make a shallow copy (#804)
This commit is contained in:
parent
f512b905c7
commit
b7588fd5d7
@ -802,7 +802,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__iadd__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@ -821,7 +821,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__isub__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@ -840,7 +840,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__imul__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@ -859,7 +859,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__itruediv__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_floating_point(a.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
"In place division cannot cast to non-floating point type.");
|
||||
@ -894,7 +894,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ifloordiv__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@ -914,7 +914,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__imod__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@ -980,7 +980,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__imatmul__",
|
||||
[](array& a, array& other) {
|
||||
[](array& a, array& other) -> array& {
|
||||
a.overwrite_descriptor(matmul(a, other));
|
||||
return a;
|
||||
},
|
||||
@ -999,7 +999,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ipow__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
@ -1034,7 +1034,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__iand__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
@ -1065,7 +1065,7 @@ void init_array(py::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ior__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
|
@ -1442,6 +1442,21 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
b @= a
|
||||
self.assertTrue(mx.array_equal(a, b))
|
||||
|
||||
def test_inplace_preserves_ids(self):
|
||||
a = mx.array([1.0])
|
||||
orig_id = id(a)
|
||||
a += mx.array(2.0)
|
||||
self.assertEqual(id(a), orig_id)
|
||||
|
||||
a[0] = 2.0
|
||||
self.assertEqual(id(a), orig_id)
|
||||
|
||||
a -= mx.array(3.0)
|
||||
self.assertEqual(id(a), orig_id)
|
||||
|
||||
a *= mx.array(3.0)
|
||||
self.assertEqual(id(a), orig_id)
|
||||
|
||||
def test_load_from_pickled_np(self):
|
||||
a = np.array([1, 2, 3], dtype=np.int32)
|
||||
b = pickle.loads(pickle.dumps(a))
|
||||
|
Loading…
Reference in New Issue
Block a user