fix inplace to not make a shallow copy (#804)

This commit is contained in:
Awni Hannun
2024-03-07 09:34:11 -08:00
committed by GitHub
parent f512b905c7
commit b7588fd5d7
2 changed files with 25 additions and 10 deletions

View File

@@ -802,7 +802,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__iadd__",
[](array& a, const ScalarOrArray v) {
[](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
return a;
},
@@ -821,7 +821,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__isub__",
[](array& a, const ScalarOrArray v) {
[](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
return a;
},
@@ -840,7 +840,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__imul__",
[](array& a, const ScalarOrArray v) {
[](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
return a;
},
@@ -859,7 +859,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__itruediv__",
[](array& a, const ScalarOrArray v) {
[](array& a, const ScalarOrArray v) -> array& {
if (!is_floating_point(a.dtype())) {
throw std::invalid_argument(
"In place division cannot cast to non-floating point type.");
@@ -894,7 +894,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__ifloordiv__",
[](array& a, const ScalarOrArray v) {
[](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
return a;
},
@@ -914,7 +914,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__imod__",
[](array& a, const ScalarOrArray v) {
[](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
return a;
},
@@ -980,7 +980,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__imatmul__",
[](array& a, array& other) {
[](array& a, array& other) -> array& {
a.overwrite_descriptor(matmul(a, other));
return a;
},
@@ -999,7 +999,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__ipow__",
[](array& a, const ScalarOrArray v) {
[](array& a, const ScalarOrArray v) -> array& {
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
return a;
},
@@ -1034,7 +1034,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__iand__",
[](array& a, const ScalarOrArray v) {
[](array& a, const ScalarOrArray v) -> array& {
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
@@ -1065,7 +1065,7 @@ void init_array(py::module_& m) {
"other"_a)
.def(
"__ior__",
[](array& a, const ScalarOrArray v) {
[](array& a, const ScalarOrArray v) -> array& {
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(