Fix divide types + floor divide (//) (#138)

* divide types

* fix black + test
This commit is contained in:
Awni Hannun
2023-12-11 20:20:58 -08:00
committed by GitHub
parent 02de234ef0
commit 25f70d4ca4
4 changed files with 42 additions and 6 deletions

View File

@@ -623,25 +623,41 @@ void init_array(py::module_& m) {
.def(
"__truediv__",
[](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, float32));
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__div__",
[](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, float32));
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__floordiv__",
[](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
auto t = promote_types(a.dtype(), b.dtype());
return astype(divide(a, b), t);
},
"other"_a)
.def(
"__rtruediv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, float32), a);
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__rfloordiv__",
[](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
auto t = promote_types(a.dtype(), b.dtype());
return astype(divide(b, a), t);
},
"other"_a)
.def(
"__rdiv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, float32), a);
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(