Bitwise Inverse (#1862)

* add bitwise inverse

* add vmap + fix nojit

* inverse -> invert

* add to compile + remove unused
This commit is contained in:
Alex Barron 2025-02-13 08:44:14 -08:00 committed by GitHub
parent e425dc00c0
commit 5cd97f7ffe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 147 additions and 8 deletions

View File

@ -32,6 +32,7 @@ Operations
atleast_2d
atleast_3d
bitwise_and
bitwise_invert
bitwise_or
bitwise_xor
block_masked_mm

View File

@ -137,6 +137,11 @@ Simd<T, N> operator-(Simd<T, N> v) {
return -v.value;
}
template <typename T, int N>
Simd<T, N> operator~(Simd<T, N> v) {
return ~v.value;
}
template <typename T, int N>
Simd<bool, N> isnan(Simd<T, N> v) {
return asd::convert<char>(v.value != v.value);

View File

@ -95,6 +95,11 @@ DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value;
}
template <typename T>
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
return std::real(in.value);

View File

@ -85,6 +85,12 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_fp(in, out, detail::ArcTanh());
}
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_int(in, out, detail::BitwiseInvert());
}
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@ -141,4 +141,38 @@ void unary_fp(const array& a, array& out, Op op) {
}
}
template <typename Op>
void unary_int(const array& a, array& out, Op op) {
switch (out.dtype()) {
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
}
} // namespace mlx::core

View File

@ -34,6 +34,7 @@ DEFAULT_OP(ArcSin, asin)
DEFAULT_OP(ArcSinh, asinh)
DEFAULT_OP(ArcTan, atan)
DEFAULT_OP(ArcTanh, atanh)
DEFAULT_OP(BitwiseInvert, operator~)
DEFAULT_OP(Ceil, ceil)
DEFAULT_OP(Conjugate, conj)
DEFAULT_OP(Cos, cos)

View File

@ -21,8 +21,7 @@
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
#define instantiate_unary_int(op) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
@ -30,7 +29,11 @@
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_all_same(op, int64, int64_t) \
instantiate_unary_all_same(op, int64, int64_t)
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
instantiate_unary_int(op) \
instantiate_unary_float(op)
instantiate_unary_types(Abs)
@ -63,6 +66,7 @@ instantiate_unary_float(Rsqrt)
instantiate_unary_float(Tan)
instantiate_unary_float(Tanh)
instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert)
instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t)

View File

@ -85,6 +85,13 @@ struct ArcTanh {
};
};
struct BitwiseInvert {
template <typename T>
T operator()(T x) {
return ~x;
};
};
struct Ceil {
template <typename T>
T operator()(T x) {

View File

@ -124,6 +124,7 @@ UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)

View File

@ -33,6 +33,7 @@ NO_CPU(ArgSort)
NO_CPU(AsType)
NO_CPU(AsStrided)
NO_CPU(BitwiseBinary)
NO_CPU(BitwiseInvert)
NO_CPU(BlockMaskedMM)
NO_CPU(Broadcast)
NO_CPU(BroadcastAxes)

View File

@ -34,6 +34,7 @@ NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Broadcast)
NO_GPU(BroadcastAxes)

View File

@ -34,7 +34,7 @@ bool is_unary(const Primitive& p) {
typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) ||
typeid(p) == typeid(Imag));
typeid(p) == typeid(Imag) || typeid(p) == typeid(BitwiseInvert));
}
bool is_binary(const Primitive& p) {

View File

@ -4853,6 +4853,21 @@ array operator>>(const array& a, const array& b) {
return right_shift(a, b);
}
array bitwise_invert(const array& a, StreamOrDevice s /* = {} */) {
if (issubdtype(a.dtype(), inexact)) {
throw std::invalid_argument(
"[bitwise_invert] Bitwise inverse only allowed on integer types.");
} else if (a.dtype() == bool_) {
return logical_not(a, s);
}
return array(
a.shape(), a.dtype(), std::make_shared<BitwiseInvert>(to_stream(s)), {a});
}
array operator~(const array& a) {
return bitwise_invert(a);
}
array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
if (a.dtype() == dtype) {
return a;

View File

@ -1476,6 +1476,10 @@ array operator<<(const array& a, const array& b);
array right_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator>>(const array& a, const array& b);
/** Invert the bits. */
array bitwise_invert(const array& a, StreamOrDevice s = {});
array operator~(const array& a);
array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
/** Roll elements along an axis and introduce them on the other side */

View File

@ -4522,6 +4522,14 @@ std::pair<std::vector<array>, std::vector<int>> Tanh::vmap(
return {{tanh(inputs[0], stream())}, axes};
}
std::pair<std::vector<array>, std::vector<int>> BitwiseInvert::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{bitwise_invert(inputs[0], stream())}, axes};
}
std::vector<array> BlockMaskedMM::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -459,6 +459,19 @@ class BitwiseBinary : public UnaryPrimitive {
Op op_;
};
class BitwiseInvert : public UnaryPrimitive {
public:
explicit BitwiseInvert(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_PRINT(BitwiseInvert)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
};
class BlockMaskedMM : public UnaryPrimitive {
public:
explicit BlockMaskedMM(Stream stream, int block_size)

View File

@ -745,11 +745,10 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise inversion.");
}
if (a.dtype() != mx::bool_) {
throw std::invalid_argument(
"Bitwise inversion not yet supported for integer types.");
}
if (a.dtype() == mx::bool_) {
return mx::logical_not(a);
}
return mx::bitwise_invert(a);
})
.def(
"__and__",

View File

@ -4833,6 +4833,28 @@ void init_ops(nb::module_& m) {
Returns:
array: The bitwise right shift ``a >> b``.
)pbdoc");
m.def(
"bitwise_invert",
[](const ScalarOrArray& a_, mx::StreamOrDevice s) {
auto a = to_array(a_);
return mx::bitwise_invert(a, s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_invert(a: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise inverse.
Take the bitwise complement of the input.
Args:
a (array): Input array or scalar.
Returns:
array: The bitwise inverse ``~a``.
)pbdoc");
m.def(
"view",
[](const ScalarOrArray& a, const mx::Dtype& dtype, mx::StreamOrDevice s) {

View File

@ -2573,6 +2573,18 @@ class TestOps(mlx_tests.MLXTestCase):
out_np = getattr(np, op)(a_np, b_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
for t in types:
a_mlx = a.astype(t)
a_np = np.array(a_mlx)
out_mlx = ~a_mlx
out_np = ~a_np
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
out_mlx = mx.bitwise_invert(a_mlx)
out_np = mx.bitwise_invert(a_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
# Check broadcasting
a = mx.ones((3, 1, 5), dtype=mx.bool_)
b = mx.zeros((1, 2, 5), dtype=mx.bool_)