mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
xor op on arrays (#1875)
This commit is contained in:
parent
5274c3c43f
commit
4c1dfa58b7
@ -878,6 +878,38 @@ void init_array(nb::module_& m) {
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def(
|
||||
"__xor__",
|
||||
[](const mx::array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("bitwise xor", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise xor.");
|
||||
}
|
||||
return mx::bitwise_xor(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ixor__",
|
||||
[](mx::array& a, const ScalarOrArray v) -> mx::array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace bitwise xor", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed bitwise xor.");
|
||||
}
|
||||
a.overwrite_descriptor(mx::bitwise_xor(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); })
|
||||
.def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); })
|
||||
.def(
|
||||
|
@ -1725,6 +1725,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual((mx.array(True) | False).item(), True)
|
||||
self.assertEqual((mx.array(False) | False).item(), False)
|
||||
self.assertEqual((~mx.array(False)).item(), True)
|
||||
self.assertEqual((mx.array(False) ^ True).item(), True)
|
||||
|
||||
def test_inplace(self):
|
||||
iops = [
|
||||
@ -1734,6 +1735,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
"__ifloordiv__",
|
||||
"__imod__",
|
||||
"__ipow__",
|
||||
"__ixor__",
|
||||
]
|
||||
|
||||
for op in iops:
|
||||
@ -1773,6 +1775,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
b @= a
|
||||
self.assertTrue(mx.array_equal(a, b))
|
||||
|
||||
a = mx.array(False)
|
||||
a ^= True
|
||||
self.assertEqual(a.item(), True)
|
||||
|
||||
def test_inplace_preserves_ids(self):
|
||||
a = mx.array([1.0])
|
||||
orig_id = id(a)
|
||||
|
Loading…
Reference in New Issue
Block a user