Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)
Bitwise Inverse (#1862)

* add bitwise inverse
* add vmap + fix nojit
* inverse -> invert
* add to compile + remove unused

parent e425dc00c0
commit 5cd97f7ffe
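In user-facing terms, this change adds an element-wise bitwise NOT for integer arrays, exposed as mx.bitwise_invert and as the ~ operator. A minimal usage sketch (assuming the mlx Python package as mlx.core; values follow two's-complement semantics and mirror NumPy):

import mlx.core as mx

a = mx.array([0, 1, 2, 255], dtype=mx.uint8)

# Both spellings dispatch to the new BitwiseInvert primitive.
print(mx.bitwise_invert(a))  # expected: [255, 254, 253, 0] as uint8
print(~a)                    # same result via the overloaded ~ operator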
@@ -32,6 +32,7 @@ Operations
    atleast_2d
    atleast_3d
    bitwise_and
+   bitwise_invert
    bitwise_or
    bitwise_xor
    block_masked_mm
@@ -137,6 +137,11 @@ Simd<T, N> operator-(Simd<T, N> v) {
   return -v.value;
 }
 
+template <typename T, int N>
+Simd<T, N> operator~(Simd<T, N> v) {
+  return ~v.value;
+}
+
 template <typename T, int N>
 Simd<bool, N> isnan(Simd<T, N> v) {
   return asd::convert<char>(v.value != v.value);
@@ -95,6 +95,11 @@ DEFAULT_UNARY(sqrt, std::sqrt)
 DEFAULT_UNARY(tan, std::tan)
 DEFAULT_UNARY(tanh, std::tanh)
 
+template <typename T>
+Simd<T, 1> operator~(Simd<T, 1> in) {
+  return ~in.value;
+}
+
 template <typename T>
 auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
   return std::real(in.value);
@@ -85,6 +85,12 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
   unary_fp(in, out, detail::ArcTanh());
 }
 
+void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_int(in, out, detail::BitwiseInvert());
+}
+
 void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -141,4 +141,38 @@ void unary_fp(const array& a, array& out, Op op) {
   }
 }
 
+template <typename Op>
+void unary_int(const array& a, array& out, Op op) {
+  switch (out.dtype()) {
+    case uint8:
+      unary_op<uint8_t>(a, out, op);
+      break;
+    case uint16:
+      unary_op<uint16_t>(a, out, op);
+      break;
+    case uint32:
+      unary_op<uint32_t>(a, out, op);
+      break;
+    case uint64:
+      unary_op<uint64_t>(a, out, op);
+      break;
+    case int8:
+      unary_op<int8_t>(a, out, op);
+      break;
+    case int16:
+      unary_op<int16_t>(a, out, op);
+      break;
+    case int32:
+      unary_op<int32_t>(a, out, op);
+      break;
+    case int64:
+      unary_op<int64_t>(a, out, op);
+      break;
+    default:
+      std::ostringstream err;
+      err << "[unary_int] Does not support " << out.dtype();
+      throw std::runtime_error(err.str());
+  }
+}
+
 } // namespace mlx::core
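The unary_int dispatcher above covers all eight fixed-width integer dtypes, so the same invert produces dtype-dependent values under two's complement. A small illustrative sketch (expected values are assumptions based on standard integer semantics):

import mlx.core as mx

# Same operation, different interpretation per dtype.
print(~mx.array([1], dtype=mx.uint8))   # expected: [254] (uint8)
print(~mx.array([1], dtype=mx.int8))    # expected: [-2]  (int8)
print(~mx.array([0], dtype=mx.int32))   # expected: [-1]  (int32)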
@@ -34,6 +34,7 @@ DEFAULT_OP(ArcSin, asin)
 DEFAULT_OP(ArcSinh, asinh)
 DEFAULT_OP(ArcTan, atan)
 DEFAULT_OP(ArcTanh, atanh)
+DEFAULT_OP(BitwiseInvert, operator~)
 DEFAULT_OP(Ceil, ceil)
 DEFAULT_OP(Conjugate, conj)
 DEFAULT_OP(Cos, cos)
@@ -21,8 +21,7 @@
   instantiate_unary_all_same(op, float32, float) \
   instantiate_unary_all_same(op, bfloat16, bfloat16_t)
 
-#define instantiate_unary_types(op) \
-  instantiate_unary_all_same(op, bool_, bool) \
+#define instantiate_unary_int(op) \
   instantiate_unary_all_same(op, uint8, uint8_t) \
   instantiate_unary_all_same(op, uint16, uint16_t) \
   instantiate_unary_all_same(op, uint32, uint32_t) \
@@ -30,7 +29,11 @@
   instantiate_unary_all_same(op, int8, int8_t) \
   instantiate_unary_all_same(op, int16, int16_t) \
   instantiate_unary_all_same(op, int32, int32_t) \
-  instantiate_unary_all_same(op, int64, int64_t) \
+  instantiate_unary_all_same(op, int64, int64_t)
+
+#define instantiate_unary_types(op) \
+  instantiate_unary_all_same(op, bool_, bool) \
+  instantiate_unary_int(op) \
   instantiate_unary_float(op)
 
 instantiate_unary_types(Abs)
@@ -63,6 +66,7 @@ instantiate_unary_float(Rsqrt)
 instantiate_unary_float(Tan)
 instantiate_unary_float(Tanh)
 instantiate_unary_float(Round)
+instantiate_unary_int(BitwiseInvert)
 
 instantiate_unary_all_same(Abs, complex64, complex64_t)
 instantiate_unary_all_same(Conjugate, complex64, complex64_t)
@@ -85,6 +85,13 @@ struct ArcTanh {
   };
 };
 
+struct BitwiseInvert {
+  template <typename T>
+  T operator()(T x) {
+    return ~x;
+  };
+};
+
 struct Ceil {
   template <typename T>
   T operator()(T x) {
@@ -124,6 +124,7 @@ UNARY_GPU(ArcSin)
 UNARY_GPU(ArcSinh)
 UNARY_GPU(ArcTan)
 UNARY_GPU(ArcTanh)
+UNARY_GPU(BitwiseInvert)
 UNARY_GPU(Conjugate)
 UNARY_GPU(Cos)
 UNARY_GPU(Cosh)
@@ -33,6 +33,7 @@ NO_CPU(ArgSort)
 NO_CPU(AsType)
 NO_CPU(AsStrided)
 NO_CPU(BitwiseBinary)
+NO_CPU(BitwiseInvert)
 NO_CPU(BlockMaskedMM)
 NO_CPU(Broadcast)
 NO_CPU(BroadcastAxes)
@@ -34,6 +34,7 @@ NO_GPU(ArgSort)
 NO_GPU(AsType)
 NO_GPU(AsStrided)
 NO_GPU(BitwiseBinary)
+NO_GPU(BitwiseInvert)
 NO_GPU(BlockMaskedMM)
 NO_GPU(Broadcast)
 NO_GPU(BroadcastAxes)
@@ -34,7 +34,7 @@ bool is_unary(const Primitive& p) {
       typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
       typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
       typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) ||
-      typeid(p) == typeid(Imag));
+      typeid(p) == typeid(Imag) || typeid(p) == typeid(BitwiseInvert));
 }
 
 bool is_binary(const Primitive& p) {
mlx/ops.cpp
@@ -4853,6 +4853,21 @@ array operator>>(const array& a, const array& b) {
   return right_shift(a, b);
 }
 
+array bitwise_invert(const array& a, StreamOrDevice s /* = {} */) {
+  if (issubdtype(a.dtype(), inexact)) {
+    throw std::invalid_argument(
+        "[bitwise_invert] Bitwise inverse only allowed on integer types.");
+  } else if (a.dtype() == bool_) {
+    return logical_not(a, s);
+  }
+  return array(
+      a.shape(), a.dtype(), std::make_shared<BitwiseInvert>(to_stream(s)), {a});
+}
+
+array operator~(const array& a) {
+  return bitwise_invert(a);
+}
+
 array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
   if (a.dtype() == dtype) {
     return a;
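Note the dispatch in bitwise_invert: inexact (floating-point) inputs raise, and bool inputs are lowered to logical_not instead of the integer primitive. A sketch of the Python-visible behavior this implies (assuming nanobind's usual translation of std::invalid_argument to ValueError):

import mlx.core as mx

# bool input: routed through logical_not.
print(mx.bitwise_invert(mx.array([True, False])))  # expected: [False, True]

# float input: rejected with the error message from the C++ op above.
try:
    mx.bitwise_invert(mx.array([1.0, 2.0]))
except ValueError as e:
    print(e)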
@@ -1476,6 +1476,10 @@ array operator<<(const array& a, const array& b);
 array right_shift(const array& a, const array& b, StreamOrDevice s = {});
 array operator>>(const array& a, const array& b);
 
+/** Invert the bits. */
+array bitwise_invert(const array& a, StreamOrDevice s = {});
+array operator~(const array& a);
+
 array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
 
 /** Roll elements along an axis and introduce them on the other side */
@@ -4522,6 +4522,14 @@ std::pair<std::vector<array>, std::vector<int>> Tanh::vmap(
   return {{tanh(inputs[0], stream())}, axes};
 }
 
+std::pair<std::vector<array>, std::vector<int>> BitwiseInvert::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+  return {{bitwise_invert(inputs[0], stream())}, axes};
+}
+
 std::vector<array> BlockMaskedMM::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
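The vmap rule simply re-applies bitwise_invert to the batched input and forwards the axes, as for other element-wise unaries. A hedged sketch of what that enables from Python:

import mlx.core as mx

batched = mx.array([[1, 2], [3, 4]], dtype=mx.uint8)

# Mapping over the leading axis should match applying ~ directly,
# because the op is element-wise.
out = mx.vmap(mx.bitwise_invert)(batched)
print(mx.array_equal(out, ~batched))  # expected: True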
@@ -459,6 +459,19 @@ class BitwiseBinary : public UnaryPrimitive {
   Op op_;
 };
 
+class BitwiseInvert : public UnaryPrimitive {
+ public:
+  explicit BitwiseInvert(Stream stream) : UnaryPrimitive(stream) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_PRINT(BitwiseInvert)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+  DEFINE_INPUT_OUTPUT_SHAPE()
+};
+
 class BlockMaskedMM : public UnaryPrimitive {
  public:
  explicit BlockMaskedMM(Stream stream, int block_size)
@@ -745,11 +745,10 @@ void init_array(nb::module_& m) {
             throw std::invalid_argument(
                 "Floating point types not allowed with bitwise inversion.");
           }
-          if (a.dtype() != mx::bool_) {
-            throw std::invalid_argument(
-                "Bitwise inversion not yet supported for integer types.");
+          if (a.dtype() == mx::bool_) {
+            return mx::logical_not(a);
           }
-          return mx::logical_not(a);
+          return mx::bitwise_invert(a);
         })
     .def(
         "__and__",
@@ -4833,6 +4833,28 @@ void init_ops(nb::module_& m) {
         Returns:
             array: The bitwise right shift ``a >> b``.
       )pbdoc");
+  m.def(
+      "bitwise_invert",
+      [](const ScalarOrArray& a_, mx::StreamOrDevice s) {
+        auto a = to_array(a_);
+        return mx::bitwise_invert(a, s);
+      },
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def bitwise_invert(a: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise bitwise inverse.
+
+        Take the bitwise complement of the input.
+
+        Args:
+            a (array): Input array or scalar.
+
+        Returns:
+            array: The bitwise inverse ``~a``.
+      )pbdoc");
   m.def(
       "view",
       [](const ScalarOrArray& a, const mx::Dtype& dtype, mx::StreamOrDevice s) {
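Since the binding converts its argument with to_array, plain Python scalars are accepted too, and the keyword-only stream argument can pin execution to a device. A short sketch (the int32 default for Python ints is an assumption):

import mlx.core as mx

print(mx.bitwise_invert(3))  # expected: -4 (int32), since ~3 == -4

# stream/device is keyword-only; e.g. force evaluation on the CPU stream.
a = mx.array([1, 2, 3], dtype=mx.uint16)
print(mx.bitwise_invert(a, stream=mx.cpu))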
@@ -2573,6 +2573,18 @@ class TestOps(mlx_tests.MLXTestCase):
                 out_np = getattr(np, op)(a_np, b_np)
                 self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
 
+        for t in types:
+            a_mlx = a.astype(t)
+            a_np = np.array(a_mlx)
+
+            out_mlx = ~a_mlx
+            out_np = ~a_np
+            self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
+
+            out_mlx = mx.bitwise_invert(a_mlx)
+            out_np = mx.bitwise_invert(a_np)
+            self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
+
         # Check broadcasting
         a = mx.ones((3, 1, 5), dtype=mx.bool_)
         b = mx.zeros((1, 2, 5), dtype=mx.bool_)