mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
Fix divide types + floor divide (//) (#138)
* divide types * fix black + test
This commit is contained in:
@@ -623,25 +623,41 @@ void init_array(py::module_& m) {
|
||||
.def(
|
||||
"__truediv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return divide(a, to_array(v, float32));
|
||||
return divide(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__div__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return divide(a, to_array(v, float32));
|
||||
return divide(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__floordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
auto t = promote_types(a.dtype(), b.dtype());
|
||||
return astype(divide(a, b), t);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rtruediv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return divide(to_array(v, float32), a);
|
||||
return divide(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rfloordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
auto t = promote_types(a.dtype(), b.dtype());
|
||||
return astype(divide(b, a), t);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rdiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return divide(to_array(v, float32), a);
|
||||
return divide(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
|
Reference in New Issue
Block a user