mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
xor op on arrays (#1875)
This commit is contained in:
parent
5274c3c43f
commit
4c1dfa58b7
@ -878,6 +878,38 @@ void init_array(nb::module_& m) {
|
|||||||
},
|
},
|
||||||
"other"_a,
|
"other"_a,
|
||||||
nb::rv_policy::none)
|
nb::rv_policy::none)
|
||||||
|
.def(
|
||||||
|
"__xor__",
|
||||||
|
[](const mx::array& a, const ScalarOrArray v) {
|
||||||
|
if (!is_comparable_with_array(v)) {
|
||||||
|
throw_invalid_operation("bitwise xor", v);
|
||||||
|
}
|
||||||
|
auto b = to_array(v, a.dtype());
|
||||||
|
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||||
|
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Floating point types not allowed with bitwise xor.");
|
||||||
|
}
|
||||||
|
return mx::bitwise_xor(a, b);
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__ixor__",
|
||||||
|
[](mx::array& a, const ScalarOrArray v) -> mx::array& {
|
||||||
|
if (!is_comparable_with_array(v)) {
|
||||||
|
throw_invalid_operation("inplace bitwise xor", v);
|
||||||
|
}
|
||||||
|
auto b = to_array(v, a.dtype());
|
||||||
|
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||||
|
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Floating point types not allowed bitwise xor.");
|
||||||
|
}
|
||||||
|
a.overwrite_descriptor(mx::bitwise_xor(a, b));
|
||||||
|
return a;
|
||||||
|
},
|
||||||
|
"other"_a,
|
||||||
|
nb::rv_policy::none)
|
||||||
.def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); })
|
.def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); })
|
||||||
.def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); })
|
.def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); })
|
||||||
.def(
|
.def(
|
||||||
|
@ -1725,6 +1725,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual((mx.array(True) | False).item(), True)
|
self.assertEqual((mx.array(True) | False).item(), True)
|
||||||
self.assertEqual((mx.array(False) | False).item(), False)
|
self.assertEqual((mx.array(False) | False).item(), False)
|
||||||
self.assertEqual((~mx.array(False)).item(), True)
|
self.assertEqual((~mx.array(False)).item(), True)
|
||||||
|
self.assertEqual((mx.array(False) ^ True).item(), True)
|
||||||
|
|
||||||
def test_inplace(self):
|
def test_inplace(self):
|
||||||
iops = [
|
iops = [
|
||||||
@ -1734,6 +1735,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
"__ifloordiv__",
|
"__ifloordiv__",
|
||||||
"__imod__",
|
"__imod__",
|
||||||
"__ipow__",
|
"__ipow__",
|
||||||
|
"__ixor__",
|
||||||
]
|
]
|
||||||
|
|
||||||
for op in iops:
|
for op in iops:
|
||||||
@ -1773,6 +1775,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
b @= a
|
b @= a
|
||||||
self.assertTrue(mx.array_equal(a, b))
|
self.assertTrue(mx.array_equal(a, b))
|
||||||
|
|
||||||
|
a = mx.array(False)
|
||||||
|
a ^= True
|
||||||
|
self.assertEqual(a.item(), True)
|
||||||
|
|
||||||
def test_inplace_preserves_ids(self):
|
def test_inplace_preserves_ids(self):
|
||||||
a = mx.array([1.0])
|
a = mx.array([1.0])
|
||||||
orig_id = id(a)
|
orig_id = id(a)
|
||||||
|
Loading…
Reference in New Issue
Block a user