in place ops behave in place, fix some overloads (#411)

This commit is contained in:
Awni Hannun
2024-01-09 16:05:38 -08:00
committed by GitHub
parent 961435a243
commit 1d90a76d63
2 changed files with 202 additions and 19 deletions

View File

@@ -711,6 +711,13 @@ void init_array(py::module_& m) {
return add(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__iadd__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__radd__",
[](const array& a, const ScalarOrArray v) {
@@ -723,6 +730,13 @@ void init_array(py::module_& m) {
return subtract(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__isub__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__rsub__",
[](const array& a, const ScalarOrArray v) {
@@ -735,6 +749,13 @@ void init_array(py::module_& m) {
return multiply(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__imul__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__rmul__",
[](const array& a, const ScalarOrArray v) {
@@ -748,16 +769,14 @@ void init_array(py::module_& m) {
},
"other"_a)
.def(
"__div__",
[](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__floordiv__",
[](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
return floor_divide(a, b);
"__itruediv__",
[](array& a, const ScalarOrArray v) {
if (!is_floating_point(a.dtype())) {
throw std::invalid_argument(
"In place division cannot cast to non-floating point type.");
}
a.overwrite_descriptor(divide(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
@@ -766,6 +785,31 @@ void init_array(py::module_& m) {
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__div__",
[](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__rdiv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__floordiv__",
[](const array& a, const ScalarOrArray v) {
return floor_divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ifloordiv__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__rfloordiv__",
[](const array& a, const ScalarOrArray v) {
@@ -773,18 +817,19 @@ void init_array(py::module_& m) {
return floor_divide(b, a);
},
"other"_a)
.def(
"__rdiv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__mod__",
[](const array& a, const ScalarOrArray v) {
return remainder(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__imod__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__rmod__",
[](const array& a, const ScalarOrArray v) {
@@ -840,23 +885,102 @@ void init_array(py::module_& m) {
return os.str();
})
.def(
"__matmul__", [](array& a, array& other) { return matmul(a, other); })
"__matmul__",
[](const array& a, array& other) { return matmul(a, other); },
"other"_a)
.def(
"__imatmul__",
[](array& a, array& other) {
a.overwrite_descriptor(matmul(a, other));
return a;
},
"other"_a)
.def(
"__pow__",
[](const array& a, const ScalarOrArray v) {
return power(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ipow__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__invert__",
[](const array& a) {
if (is_floating_point(a.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise inversion.");
}
if (a.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise inversion not yet supported for integer types.");
}
return logical_not(a);
})
.def(
"__and__",
[](const array& a, const ScalarOrArray v) {
return logical_and(a, to_array(v, a.dtype()));
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
return logical_and(a, b);
},
"other"_a)
.def(
"__iand__",
[](array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
a.overwrite_descriptor(logical_and(a, b));
return a;
},
"other"_a)
.def(
"__or__",
[](const array& a, const ScalarOrArray v) {
return logical_or(a, to_array(v, a.dtype()));
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
}
return logical_or(a, b);
},
"other"_a)
.def(
"__ior__",
[](array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
}
a.overwrite_descriptor(logical_or(a, b));
return a;
},
"other"_a)
.def(