Bitwise Inverse (#1862)

* add bitwise inverse

* add vmap + fix nojit

* inverse -> invert

* add to compile + remove unused
This commit is contained in:
Alex Barron
2025-02-13 08:44:14 -08:00
committed by GitHub
parent e425dc00c0
commit 5cd97f7ffe
19 changed files with 147 additions and 8 deletions

View File

@@ -137,6 +137,11 @@ Simd<T, N> operator-(Simd<T, N> v) {
return -v.value;
}
template <typename T, int N>
Simd<T, N> operator~(Simd<T, N> v) {
return ~v.value;
}
template <typename T, int N>
Simd<bool, N> isnan(Simd<T, N> v) {
return asd::convert<char>(v.value != v.value);

View File

@@ -95,6 +95,11 @@ DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value;
}
template <typename T>
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
return std::real(in.value);

View File

@@ -85,6 +85,12 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_fp(in, out, detail::ArcTanh());
}
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_int(in, out, detail::BitwiseInvert());
}
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@@ -141,4 +141,38 @@ void unary_fp(const array& a, array& out, Op op) {
}
}
template <typename Op>
void unary_int(const array& a, array& out, Op op) {
switch (out.dtype()) {
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
}
} // namespace mlx::core

View File

@@ -34,6 +34,7 @@ DEFAULT_OP(ArcSin, asin)
DEFAULT_OP(ArcSinh, asinh)
DEFAULT_OP(ArcTan, atan)
DEFAULT_OP(ArcTanh, atanh)
DEFAULT_OP(BitwiseInvert, operator~)
DEFAULT_OP(Ceil, ceil)
DEFAULT_OP(Conjugate, conj)
DEFAULT_OP(Cos, cos)

View File

@@ -21,8 +21,7 @@
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
#define instantiate_unary_int(op) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
@@ -30,7 +29,11 @@
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_all_same(op, int64, int64_t) \
instantiate_unary_all_same(op, int64, int64_t)
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
instantiate_unary_int(op) \
instantiate_unary_float(op)
instantiate_unary_types(Abs)
@@ -63,6 +66,7 @@ instantiate_unary_float(Rsqrt)
instantiate_unary_float(Tan)
instantiate_unary_float(Tanh)
instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert)
instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t)

View File

@@ -85,6 +85,13 @@ struct ArcTanh {
};
};
struct BitwiseInvert {
template <typename T>
T operator()(T x) {
return ~x;
};
};
struct Ceil {
template <typename T>
T operator()(T x) {

View File

@@ -124,6 +124,7 @@ UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)

View File

@@ -33,6 +33,7 @@ NO_CPU(ArgSort)
NO_CPU(AsType)
NO_CPU(AsStrided)
NO_CPU(BitwiseBinary)
NO_CPU(BitwiseInvert)
NO_CPU(BlockMaskedMM)
NO_CPU(Broadcast)
NO_CPU(BroadcastAxes)

View File

@@ -34,6 +34,7 @@ NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Broadcast)
NO_GPU(BroadcastAxes)