mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
xor op on arrays (#1875)
This commit is contained in:
@@ -878,6 +878,38 @@ void init_array(nb::module_& m) {
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def(
|
||||
"__xor__",
|
||||
[](const mx::array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("bitwise xor", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise xor.");
|
||||
}
|
||||
return mx::bitwise_xor(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ixor__",
|
||||
[](mx::array& a, const ScalarOrArray v) -> mx::array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace bitwise xor", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed bitwise xor.");
|
||||
}
|
||||
a.overwrite_descriptor(mx::bitwise_xor(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); })
|
||||
.def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); })
|
||||
.def(
|
||||
|
||||
Reference in New Issue
Block a user