mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -108,105 +108,105 @@ struct Abs {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::abs(x);
|
||||
};
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acos(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acosh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asin(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asinh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atan(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return std::atan2(y, x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atanh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::ceil(x);
|
||||
};
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
@@ -219,35 +219,35 @@ struct Cos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cos(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cosh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erf(static_cast<float>(x)));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return fast_exp(x);
|
||||
};
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return std::exp(x);
|
||||
@@ -258,83 +258,83 @@ struct Expm1 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return expm1(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::floor(x);
|
||||
};
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log2(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log10(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return log1p(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return !x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return -x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
@@ -379,49 +379,49 @@ struct Sin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sin(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sinh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x * x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sqrt(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tan(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tanh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Add {
|
||||
@@ -554,7 +554,7 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: static_cast<decltype(x)>(
|
||||
maxval + std::log1p(fast_exp(minval - maxval)));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
@@ -602,14 +602,14 @@ struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x && y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x || y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Select {
|
||||
@@ -623,35 +623,35 @@ struct BitwiseAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
||||
Reference in New Issue
Block a user