mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	fix inplace to not make a shallow copy (#804)
This commit is contained in:
		@@ -802,7 +802,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__iadd__",
 | 
					          "__iadd__",
 | 
				
			||||||
          [](array& a, const ScalarOrArray v) {
 | 
					          [](array& a, const ScalarOrArray v) -> array& {
 | 
				
			||||||
            a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
 | 
					            a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
 | 
				
			||||||
            return a;
 | 
					            return a;
 | 
				
			||||||
          },
 | 
					          },
 | 
				
			||||||
@@ -821,7 +821,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__isub__",
 | 
					          "__isub__",
 | 
				
			||||||
          [](array& a, const ScalarOrArray v) {
 | 
					          [](array& a, const ScalarOrArray v) -> array& {
 | 
				
			||||||
            a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
 | 
					            a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
 | 
				
			||||||
            return a;
 | 
					            return a;
 | 
				
			||||||
          },
 | 
					          },
 | 
				
			||||||
@@ -840,7 +840,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__imul__",
 | 
					          "__imul__",
 | 
				
			||||||
          [](array& a, const ScalarOrArray v) {
 | 
					          [](array& a, const ScalarOrArray v) -> array& {
 | 
				
			||||||
            a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
 | 
					            a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
 | 
				
			||||||
            return a;
 | 
					            return a;
 | 
				
			||||||
          },
 | 
					          },
 | 
				
			||||||
@@ -859,7 +859,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__itruediv__",
 | 
					          "__itruediv__",
 | 
				
			||||||
          [](array& a, const ScalarOrArray v) {
 | 
					          [](array& a, const ScalarOrArray v) -> array& {
 | 
				
			||||||
            if (!is_floating_point(a.dtype())) {
 | 
					            if (!is_floating_point(a.dtype())) {
 | 
				
			||||||
              throw std::invalid_argument(
 | 
					              throw std::invalid_argument(
 | 
				
			||||||
                  "In place division cannot cast to non-floating point type.");
 | 
					                  "In place division cannot cast to non-floating point type.");
 | 
				
			||||||
@@ -894,7 +894,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__ifloordiv__",
 | 
					          "__ifloordiv__",
 | 
				
			||||||
          [](array& a, const ScalarOrArray v) {
 | 
					          [](array& a, const ScalarOrArray v) -> array& {
 | 
				
			||||||
            a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
 | 
					            a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
 | 
				
			||||||
            return a;
 | 
					            return a;
 | 
				
			||||||
          },
 | 
					          },
 | 
				
			||||||
@@ -914,7 +914,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__imod__",
 | 
					          "__imod__",
 | 
				
			||||||
          [](array& a, const ScalarOrArray v) {
 | 
					          [](array& a, const ScalarOrArray v) -> array& {
 | 
				
			||||||
            a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
 | 
					            a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
 | 
				
			||||||
            return a;
 | 
					            return a;
 | 
				
			||||||
          },
 | 
					          },
 | 
				
			||||||
@@ -980,7 +980,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__imatmul__",
 | 
					          "__imatmul__",
 | 
				
			||||||
          [](array& a, array& other) {
 | 
					          [](array& a, array& other) -> array& {
 | 
				
			||||||
            a.overwrite_descriptor(matmul(a, other));
 | 
					            a.overwrite_descriptor(matmul(a, other));
 | 
				
			||||||
            return a;
 | 
					            return a;
 | 
				
			||||||
          },
 | 
					          },
 | 
				
			||||||
@@ -999,7 +999,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__ipow__",
 | 
					          "__ipow__",
 | 
				
			||||||
          [](array& a, const ScalarOrArray v) {
 | 
					          [](array& a, const ScalarOrArray v) -> array& {
 | 
				
			||||||
            a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
 | 
					            a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
 | 
				
			||||||
            return a;
 | 
					            return a;
 | 
				
			||||||
          },
 | 
					          },
 | 
				
			||||||
@@ -1034,7 +1034,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__iand__",
 | 
					          "__iand__",
 | 
				
			||||||
          [](array& a, const ScalarOrArray v) {
 | 
					          [](array& a, const ScalarOrArray v) -> array& {
 | 
				
			||||||
            auto b = to_array(v, a.dtype());
 | 
					            auto b = to_array(v, a.dtype());
 | 
				
			||||||
            if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
 | 
					            if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
 | 
				
			||||||
              throw std::invalid_argument(
 | 
					              throw std::invalid_argument(
 | 
				
			||||||
@@ -1065,7 +1065,7 @@ void init_array(py::module_& m) {
 | 
				
			|||||||
          "other"_a)
 | 
					          "other"_a)
 | 
				
			||||||
      .def(
 | 
					      .def(
 | 
				
			||||||
          "__ior__",
 | 
					          "__ior__",
 | 
				
			||||||
          [](array& a, const ScalarOrArray v) {
 | 
					          [](array& a, const ScalarOrArray v) -> array& {
 | 
				
			||||||
            auto b = to_array(v, a.dtype());
 | 
					            auto b = to_array(v, a.dtype());
 | 
				
			||||||
            if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
 | 
					            if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
 | 
				
			||||||
              throw std::invalid_argument(
 | 
					              throw std::invalid_argument(
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1442,6 +1442,21 @@ class TestArray(mlx_tests.MLXTestCase):
 | 
				
			|||||||
        b @= a
 | 
					        b @= a
 | 
				
			||||||
        self.assertTrue(mx.array_equal(a, b))
 | 
					        self.assertTrue(mx.array_equal(a, b))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_inplace_preserves_ids(self):
 | 
				
			||||||
 | 
					        a = mx.array([1.0])
 | 
				
			||||||
 | 
					        orig_id = id(a)
 | 
				
			||||||
 | 
					        a += mx.array(2.0)
 | 
				
			||||||
 | 
					        self.assertEqual(id(a), orig_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        a[0] = 2.0
 | 
				
			||||||
 | 
					        self.assertEqual(id(a), orig_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        a -= mx.array(3.0)
 | 
				
			||||||
 | 
					        self.assertEqual(id(a), orig_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        a *= mx.array(3.0)
 | 
				
			||||||
 | 
					        self.assertEqual(id(a), orig_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_load_from_pickled_np(self):
 | 
					    def test_load_from_pickled_np(self):
 | 
				
			||||||
        a = np.array([1, 2, 3], dtype=np.int32)
 | 
					        a = np.array([1, 2, 3], dtype=np.int32)
 | 
				
			||||||
        b = pickle.loads(pickle.dumps(a))
 | 
					        b = pickle.loads(pickle.dumps(a))
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user