Bitwise Inverse (#1862)

* add bitwise inverse

* add vmap + fix nojit

* inverse -> invert

* add to compile + remove unused
This commit is contained in:
Alex Barron
2025-02-13 08:44:14 -08:00
committed by GitHub
parent e425dc00c0
commit 5cd97f7ffe
19 changed files with 147 additions and 8 deletions

View File

@@ -745,11 +745,10 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise inversion.");
}
if (a.dtype() != mx::bool_) {
throw std::invalid_argument(
"Bitwise inversion not yet supported for integer types.");
if (a.dtype() == mx::bool_) {
return mx::logical_not(a);
}
return mx::logical_not(a);
return mx::bitwise_invert(a);
})
.def(
"__and__",

View File

@@ -4833,6 +4833,28 @@ void init_ops(nb::module_& m) {
Returns:
array: The bitwise right shift ``a >> b``.
)pbdoc");
m.def(
"bitwise_invert",
[](const ScalarOrArray& a_, mx::StreamOrDevice s) {
auto a = to_array(a_);
return mx::bitwise_invert(a, s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_invert(a: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise inverse.
Take the bitwise complement of the input.
Args:
a (array): Input array or scalar.
Returns:
array: The bitwise inverse ``~a``.
)pbdoc");
m.def(
"view",
[](const ScalarOrArray& a, const mx::Dtype& dtype, mx::StreamOrDevice s) {