From 5cd97f7ffe8c41b62b887806795b8ebf50073e7e Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Thu, 13 Feb 2025 08:44:14 -0800 Subject: [PATCH] Bitwise Inverse (#1862) * add bitwise inverse * add vmap + fix nojit * inverse -> invert * add to compile + remove unused --- docs/src/python/ops.rst | 1 + mlx/backend/cpu/simd/accelerate_simd.h | 5 ++++ mlx/backend/cpu/simd/base_simd.h | 5 ++++ mlx/backend/cpu/unary.cpp | 6 +++++ mlx/backend/cpu/unary.h | 34 ++++++++++++++++++++++++++ mlx/backend/cpu/unary_ops.h | 1 + mlx/backend/metal/kernels/unary.metal | 10 +++++--- mlx/backend/metal/kernels/unary_ops.h | 7 ++++++ mlx/backend/metal/unary.cpp | 1 + mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_metal/primitives.cpp | 1 + mlx/compile.cpp | 2 +- mlx/ops.cpp | 15 ++++++++++++ mlx/ops.h | 4 +++ mlx/primitives.cpp | 8 ++++++ mlx/primitives.h | 13 ++++++++++ python/src/array.cpp | 7 +++--- python/src/ops.cpp | 22 +++++++++++++++++ python/tests/test_ops.py | 12 +++++++++ 19 files changed, 147 insertions(+), 8 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 248028575..c0d098b21 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -32,6 +32,7 @@ Operations atleast_2d atleast_3d bitwise_and + bitwise_invert bitwise_or bitwise_xor block_masked_mm diff --git a/mlx/backend/cpu/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h index 59821d03b..a14d99103 100644 --- a/mlx/backend/cpu/simd/accelerate_simd.h +++ b/mlx/backend/cpu/simd/accelerate_simd.h @@ -137,6 +137,11 @@ Simd operator-(Simd v) { return -v.value; } +template +Simd operator~(Simd v) { + return ~v.value; +} + template Simd isnan(Simd v) { return asd::convert(v.value != v.value); diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h index 37fc576a0..bc416fc22 100644 --- a/mlx/backend/cpu/simd/base_simd.h +++ b/mlx/backend/cpu/simd/base_simd.h @@ -95,6 +95,11 @@ DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(tan, std::tan) 
DEFAULT_UNARY(tanh, std::tanh) +template +Simd operator~(Simd in) { + return ~in.value; +} + template auto real(Simd in) -> Simd { return std::real(in.value); diff --git a/mlx/backend/cpu/unary.cpp b/mlx/backend/cpu/unary.cpp index 0ac4daf27..1b16be395 100644 --- a/mlx/backend/cpu/unary.cpp +++ b/mlx/backend/cpu/unary.cpp @@ -85,6 +85,12 @@ void ArcTanh::eval_cpu(const std::vector& inputs, array& out) { unary_fp(in, out, detail::ArcTanh()); } +void BitwiseInvert::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + unary_int(in, out, detail::BitwiseInvert()); +} + void Ceil::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h index edfe2a7b4..d5c30526b 100644 --- a/mlx/backend/cpu/unary.h +++ b/mlx/backend/cpu/unary.h @@ -141,4 +141,38 @@ void unary_fp(const array& a, array& out, Op op) { } } +template +void unary_int(const array& a, array& out, Op op) { + switch (out.dtype()) { + case uint8: + unary_op(a, out, op); + break; + case uint16: + unary_op(a, out, op); + break; + case uint32: + unary_op(a, out, op); + break; + case uint64: + unary_op(a, out, op); + break; + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_int] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h index 3019ad91e..aab1e0de2 100644 --- a/mlx/backend/cpu/unary_ops.h +++ b/mlx/backend/cpu/unary_ops.h @@ -34,6 +34,7 @@ DEFAULT_OP(ArcSin, asin) DEFAULT_OP(ArcSinh, asinh) DEFAULT_OP(ArcTan, atan) DEFAULT_OP(ArcTanh, atanh) +DEFAULT_OP(BitwiseInvert, operator~) DEFAULT_OP(Ceil, ceil) DEFAULT_OP(Conjugate, conj) 
DEFAULT_OP(Cos, cos) diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index c64e05931..82692c8e5 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -21,8 +21,7 @@ instantiate_unary_all_same(op, float32, float) \ instantiate_unary_all_same(op, bfloat16, bfloat16_t) -#define instantiate_unary_types(op) \ - instantiate_unary_all_same(op, bool_, bool) \ +#define instantiate_unary_int(op) \ instantiate_unary_all_same(op, uint8, uint8_t) \ instantiate_unary_all_same(op, uint16, uint16_t) \ instantiate_unary_all_same(op, uint32, uint32_t) \ @@ -30,7 +29,11 @@ instantiate_unary_all_same(op, int8, int8_t) \ instantiate_unary_all_same(op, int16, int16_t) \ instantiate_unary_all_same(op, int32, int32_t) \ - instantiate_unary_all_same(op, int64, int64_t) \ + instantiate_unary_all_same(op, int64, int64_t) + +#define instantiate_unary_types(op) \ + instantiate_unary_all_same(op, bool_, bool) \ + instantiate_unary_int(op) \ instantiate_unary_float(op) instantiate_unary_types(Abs) @@ -63,6 +66,7 @@ instantiate_unary_float(Rsqrt) instantiate_unary_float(Tan) instantiate_unary_float(Tanh) instantiate_unary_float(Round) +instantiate_unary_int(BitwiseInvert) instantiate_unary_all_same(Abs, complex64, complex64_t) instantiate_unary_all_same(Conjugate, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index ebb5573e4..ceed3efe5 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -85,6 +85,13 @@ struct ArcTanh { }; }; +struct BitwiseInvert { + template + T operator()(T x) { + return ~x; + }; +}; + struct Ceil { template T operator()(T x) { diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 77f82b8fc..2652ecbfa 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -124,6 +124,7 @@ UNARY_GPU(ArcSin) UNARY_GPU(ArcSinh) UNARY_GPU(ArcTan) 
UNARY_GPU(ArcTanh) +UNARY_GPU(BitwiseInvert) UNARY_GPU(Conjugate) UNARY_GPU(Cos) UNARY_GPU(Cosh) diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index d6d6f91ed..3733af28d 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -33,6 +33,7 @@ NO_CPU(ArgSort) NO_CPU(AsType) NO_CPU(AsStrided) NO_CPU(BitwiseBinary) +NO_CPU(BitwiseInvert) NO_CPU(BlockMaskedMM) NO_CPU(Broadcast) NO_CPU(BroadcastAxes) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 789a576db..6e37a1d2b 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -34,6 +34,7 @@ NO_GPU(ArgSort) NO_GPU(AsType) NO_GPU(AsStrided) NO_GPU(BitwiseBinary) +NO_GPU(BitwiseInvert) NO_GPU(BlockMaskedMM) NO_GPU(Broadcast) NO_GPU(BroadcastAxes) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 2c3d4484a..8b7f24c91 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -34,7 +34,7 @@ bool is_unary(const Primitive& p) { typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) || - typeid(p) == typeid(Imag)); + typeid(p) == typeid(Imag) || typeid(p) == typeid(BitwiseInvert)); } bool is_binary(const Primitive& p) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 03785c11d..514a244b0 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4853,6 +4853,21 @@ array operator>>(const array& a, const array& b) { return right_shift(a, b); } +array bitwise_invert(const array& a, StreamOrDevice s /* = {} */) { + if (issubdtype(a.dtype(), inexact)) { + throw std::invalid_argument( + "[bitwise_invert] Bitwise inverse only allowed on integer types."); + } else if (a.dtype() == bool_) { + return logical_not(a, s); + } + return array( + a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); +} + +array operator~(const array& a) { + return bitwise_invert(a); +} + array view(const 
array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) { if (a.dtype() == dtype) { return a; diff --git a/mlx/ops.h b/mlx/ops.h index c0cbc2780..02428b974 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1476,6 +1476,10 @@ array operator<<(const array& a, const array& b); array right_shift(const array& a, const array& b, StreamOrDevice s = {}); array operator>>(const array& a, const array& b); +/** Invert the bits. */ +array bitwise_invert(const array& a, StreamOrDevice s = {}); +array operator~(const array& a); + array view(const array& a, const Dtype& dtype, StreamOrDevice s = {}); /** Roll elements along an axis and introduce them on the other side */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 16acfdc9c..ac4c17938 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4522,6 +4522,14 @@ std::pair, std::vector> Tanh::vmap( return {{tanh(inputs[0], stream())}, axes}; } +std::pair, std::vector> BitwiseInvert::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {{bitwise_invert(inputs[0], stream())}, axes}; +} + std::vector BlockMaskedMM::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 19f058fd3..a73ddef96 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -459,6 +459,19 @@ class BitwiseBinary : public UnaryPrimitive { Op op_; }; +class BitwiseInvert : public UnaryPrimitive { + public: + explicit BitwiseInvert(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_PRINT(BitwiseInvert) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + class BlockMaskedMM : public UnaryPrimitive { public: explicit BlockMaskedMM(Stream stream, int block_size) diff --git a/python/src/array.cpp b/python/src/array.cpp index 8f0040b1e..514d3b3d9 100644 
--- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -745,11 +745,10 @@ void init_array(nb::module_& m) { throw std::invalid_argument( "Floating point types not allowed with bitwise inversion."); } - if (a.dtype() != mx::bool_) { - throw std::invalid_argument( - "Bitwise inversion not yet supported for integer types."); + if (a.dtype() == mx::bool_) { + return mx::logical_not(a); } - return mx::logical_not(a); + return mx::bitwise_invert(a); }) .def( "__and__", diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a4f5e96e5..1577cae18 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4833,6 +4833,28 @@ void init_ops(nb::module_& m) { Returns: array: The bitwise right shift ``a >> b``. )pbdoc"); + m.def( + "bitwise_invert", + [](const ScalarOrArray& a_, mx::StreamOrDevice s) { + auto a = to_array(a_); + return mx::bitwise_invert(a, s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def bitwise_invert(a: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Element-wise bitwise inverse. + + Take the bitwise complement of the input. + + Args: + a (array): Input array or scalar. + + Returns: + array: The bitwise inverse ``~a``. 
+ )pbdoc"); m.def( "view", [](const ScalarOrArray& a, const mx::Dtype& dtype, mx::StreamOrDevice s) { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index ea798c18c..9c349bb1c 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2573,6 +2573,18 @@ class TestOps(mlx_tests.MLXTestCase): out_np = getattr(np, op)(a_np, b_np) self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) + for t in types: + a_mlx = a.astype(t) + a_np = np.array(a_mlx) + + out_mlx = ~a_mlx + out_np = ~a_np + self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) + + out_mlx = mx.bitwise_invert(a_mlx) + out_np = np.bitwise_not(a_np) # reference must come from NumPy, not MLX + self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) + # Check broadcasting a = mx.ones((3, 1, 5), dtype=mx.bool_) b = mx.zeros((1, 2, 5), dtype=mx.bool_)