mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
in place ops behave in place, fix some overloads (#411)
This commit is contained in:
@@ -711,6 +711,13 @@ void init_array(py::module_& m) {
|
||||
return add(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__iadd__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__radd__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
@@ -723,6 +730,13 @@ void init_array(py::module_& m) {
|
||||
return subtract(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__isub__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rsub__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
@@ -735,6 +749,13 @@ void init_array(py::module_& m) {
|
||||
return multiply(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__imul__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rmul__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
@@ -748,16 +769,14 @@ void init_array(py::module_& m) {
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__div__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return divide(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__floordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
return floor_divide(a, b);
|
||||
"__itruediv__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
if (!is_floating_point(a.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
"In place division cannot cast to non-floating point type.");
|
||||
}
|
||||
a.overwrite_descriptor(divide(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
@@ -766,6 +785,31 @@ void init_array(py::module_& m) {
|
||||
return divide(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__div__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return divide(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rdiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return divide(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__floordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return floor_divide(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ifloordiv__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rfloordiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
@@ -773,18 +817,19 @@ void init_array(py::module_& m) {
|
||||
return floor_divide(b, a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rdiv__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return divide(to_array(v, a.dtype()), a);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__mod__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return remainder(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__imod__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__rmod__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
@@ -840,23 +885,102 @@ void init_array(py::module_& m) {
|
||||
return os.str();
|
||||
})
|
||||
.def(
|
||||
"__matmul__", [](array& a, array& other) { return matmul(a, other); })
|
||||
"__matmul__",
|
||||
[](const array& a, array& other) { return matmul(a, other); },
|
||||
"other"_a)
|
||||
.def(
|
||||
"__imatmul__",
|
||||
[](array& a, array& other) {
|
||||
a.overwrite_descriptor(matmul(a, other));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__pow__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return power(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ipow__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__invert__",
|
||||
[](const array& a) {
|
||||
if (is_floating_point(a.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise inversion.");
|
||||
}
|
||||
if (a.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise inversion not yet supported for integer types.");
|
||||
}
|
||||
return logical_not(a);
|
||||
})
|
||||
.def(
|
||||
"__and__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return logical_and(a, to_array(v, a.dtype()));
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise and not yet supported for integer types.");
|
||||
}
|
||||
return logical_and(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__iand__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise and not yet supported for integer types.");
|
||||
}
|
||||
a.overwrite_descriptor(logical_and(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__or__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return logical_or(a, to_array(v, a.dtype()));
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise or not yet supported for integer types.");
|
||||
}
|
||||
return logical_or(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ior__",
|
||||
[](array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise or not yet supported for integer types.");
|
||||
}
|
||||
a.overwrite_descriptor(logical_or(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
|
Reference in New Issue
Block a user