From 4c1dfa58b7c076bae187e249fa49c6a0337a724c Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Mon, 17 Feb 2025 00:24:53 -0800 Subject: [PATCH] xor op on arrays (#1875) --- python/src/array.cpp | 32 ++++++++++++++++++++++++++++++++ python/tests/test_array.py | 6 ++++++ 2 files changed, 38 insertions(+) diff --git a/python/src/array.cpp b/python/src/array.cpp index 514d3b3d9..375f9a3ec 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -878,6 +878,38 @@ void init_array(nb::module_& m) { }, "other"_a, nb::rv_policy::none) + .def( + "__xor__", + [](const mx::array& a, const ScalarOrArray v) { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("bitwise xor", v); + } + auto b = to_array(v, a.dtype()); + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { + throw std::invalid_argument( + "Floating point types not allowed with bitwise xor."); + } + return mx::bitwise_xor(a, b); + }, + "other"_a) + .def( + "__ixor__", + [](mx::array& a, const ScalarOrArray v) -> mx::array& { + if (!is_comparable_with_array(v)) { + throw_invalid_operation("inplace bitwise xor", v); + } + auto b = to_array(v, a.dtype()); + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { + throw std::invalid_argument( + "Floating point types not allowed bitwise xor."); + } + a.overwrite_descriptor(mx::bitwise_xor(a, b)); + return a; + }, + "other"_a, + nb::rv_policy::none) .def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); }) .def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); }) .def( diff --git a/python/tests/test_array.py b/python/tests/test_array.py index a5515b87f..b8917b75c 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1725,6 +1725,7 @@ class TestArray(mlx_tests.MLXTestCase): self.assertEqual((mx.array(True) | False).item(), True) self.assertEqual((mx.array(False) | False).item(), False) self.assertEqual((~mx.array(False)).item(), True) + self.assertEqual((mx.array(False) ^ True).item(), True) def test_inplace(self): iops = [ @@ -1734,6 +1735,7 @@ class TestArray(mlx_tests.MLXTestCase): "__ifloordiv__", "__imod__", "__ipow__", + "__ixor__", ] for op in iops: @@ -1773,6 +1775,10 @@ class TestArray(mlx_tests.MLXTestCase): b @= a self.assertTrue(mx.array_equal(a, b)) + a = mx.array(False) + a ^= True + self.assertEqual(a.item(), True) + def test_inplace_preserves_ids(self): a = mx.array([1.0]) orig_id = id(a)