mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Bitwise Inverse (#1862)
* add bitwise inverse * add vmap + fix nojit * inverse -> invert * add to compile + remove unused
This commit is contained in:
@@ -137,6 +137,11 @@ Simd<T, N> operator-(Simd<T, N> v) {
|
||||
return -v.value;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> operator~(Simd<T, N> v) {
|
||||
return ~v.value;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<bool, N> isnan(Simd<T, N> v) {
|
||||
return asd::convert<char>(v.value != v.value);
|
||||
|
||||
@@ -95,6 +95,11 @@ DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> operator~(Simd<T, 1> in) {
|
||||
return ~in.value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
|
||||
return std::real(in.value);
|
||||
|
||||
@@ -85,6 +85,12 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
}
|
||||
|
||||
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_int(in, out, detail::BitwiseInvert());
|
||||
}
|
||||
|
||||
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@@ -141,4 +141,38 @@ void unary_fp(const array& a, array& out, Op op) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_int(const array& a, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case uint8:
|
||||
unary_op<uint8_t>(a, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
unary_op<uint16_t>(a, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
unary_op<uint32_t>(a, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
unary_op<uint64_t>(a, out, op);
|
||||
break;
|
||||
case int8:
|
||||
unary_op<int8_t>(a, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(a, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(a, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream err;
|
||||
err << "[unary_int] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -34,6 +34,7 @@ DEFAULT_OP(ArcSin, asin)
|
||||
DEFAULT_OP(ArcSinh, asinh)
|
||||
DEFAULT_OP(ArcTan, atan)
|
||||
DEFAULT_OP(ArcTanh, atanh)
|
||||
DEFAULT_OP(BitwiseInvert, operator~)
|
||||
DEFAULT_OP(Ceil, ceil)
|
||||
DEFAULT_OP(Conjugate, conj)
|
||||
DEFAULT_OP(Cos, cos)
|
||||
|
||||
Reference in New Issue
Block a user