xor op on arrays (#1875)

This commit is contained in:
Alex Barron
2025-02-17 00:24:53 -08:00
committed by GitHub
parent 5274c3c43f
commit 4c1dfa58b7
2 changed files with 38 additions and 0 deletions

View File

@@ -878,6 +878,38 @@ void init_array(nb::module_& m) {
},
"other"_a,
nb::rv_policy::none)
.def(
"__xor__",
[](const mx::array& a, const ScalarOrArray v) {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("bitwise xor", v);
}
auto b = to_array(v, a.dtype());
if (mx::issubdtype(a.dtype(), mx::inexact) ||
mx::issubdtype(b.dtype(), mx::inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise xor.");
}
return mx::bitwise_xor(a, b);
},
"other"_a)
.def(
"__ixor__",
[](mx::array& a, const ScalarOrArray v) -> mx::array& {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("inplace bitwise xor", v);
}
auto b = to_array(v, a.dtype());
if (mx::issubdtype(a.dtype(), mx::inexact) ||
mx::issubdtype(b.dtype(), mx::inexact)) {
throw std::invalid_argument(
"Floating point types not allowed bitwise xor.");
}
a.overwrite_descriptor(mx::bitwise_xor(a, b));
return a;
},
"other"_a,
nb::rv_policy::none)
.def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); })
.def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); })
.def(