xor op on arrays (#1875)

This commit is contained in:
Alex Barron 2025-02-17 00:24:53 -08:00 committed by GitHub
parent 5274c3c43f
commit 4c1dfa58b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 0 deletions

View File

@ -878,6 +878,38 @@ void init_array(nb::module_& m) {
},
"other"_a,
nb::rv_policy::none)
.def(
"__xor__",
[](const mx::array& a, const ScalarOrArray v) {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("bitwise xor", v);
}
auto b = to_array(v, a.dtype());
if (mx::issubdtype(a.dtype(), mx::inexact) ||
mx::issubdtype(b.dtype(), mx::inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise xor.");
}
return mx::bitwise_xor(a, b);
},
"other"_a)
.def(
"__ixor__",
[](mx::array& a, const ScalarOrArray v) -> mx::array& {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("inplace bitwise xor", v);
}
auto b = to_array(v, a.dtype());
if (mx::issubdtype(a.dtype(), mx::inexact) ||
mx::issubdtype(b.dtype(), mx::inexact)) {
throw std::invalid_argument(
"Floating point types not allowed bitwise xor.");
}
a.overwrite_descriptor(mx::bitwise_xor(a, b));
return a;
},
"other"_a,
nb::rv_policy::none)
.def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); })
.def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); })
.def(

View File

@ -1725,6 +1725,7 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual((mx.array(True) | False).item(), True)
self.assertEqual((mx.array(False) | False).item(), False)
self.assertEqual((~mx.array(False)).item(), True)
self.assertEqual((mx.array(False) ^ True).item(), True)
def test_inplace(self):
iops = [
@ -1734,6 +1735,7 @@ class TestArray(mlx_tests.MLXTestCase):
"__ifloordiv__",
"__imod__",
"__ipow__",
"__ixor__",
]
for op in iops:
@ -1773,6 +1775,10 @@ class TestArray(mlx_tests.MLXTestCase):
b @= a
self.assertTrue(mx.array_equal(a, b))
a = mx.array(False)
a ^= True
self.assertEqual(a.item(), True)
def test_inplace_preserves_ids(self):
a = mx.array([1.0])
orig_id = id(a)