mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	fix inplace to not make a shallow copy (#804)
This commit is contained in:
		@@ -802,7 +802,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__iadd__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) {
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -821,7 +821,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__isub__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) {
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -840,7 +840,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__imul__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) {
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -859,7 +859,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__itruediv__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) {
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            if (!is_floating_point(a.dtype())) {
 | 
			
		||||
              throw std::invalid_argument(
 | 
			
		||||
                  "In place division cannot cast to non-floating point type.");
 | 
			
		||||
@@ -894,7 +894,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__ifloordiv__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) {
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -914,7 +914,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__imod__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) {
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -980,7 +980,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__imatmul__",
 | 
			
		||||
          [](array& a, array& other) {
 | 
			
		||||
          [](array& a, array& other) -> array& {
 | 
			
		||||
            a.overwrite_descriptor(matmul(a, other));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -999,7 +999,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__ipow__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) {
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
 | 
			
		||||
            return a;
 | 
			
		||||
          },
 | 
			
		||||
@@ -1034,7 +1034,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__iand__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) {
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            auto b = to_array(v, a.dtype());
 | 
			
		||||
            if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
 | 
			
		||||
              throw std::invalid_argument(
 | 
			
		||||
@@ -1065,7 +1065,7 @@ void init_array(py::module_& m) {
 | 
			
		||||
          "other"_a)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__ior__",
 | 
			
		||||
          [](array& a, const ScalarOrArray v) {
 | 
			
		||||
          [](array& a, const ScalarOrArray v) -> array& {
 | 
			
		||||
            auto b = to_array(v, a.dtype());
 | 
			
		||||
            if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
 | 
			
		||||
              throw std::invalid_argument(
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user