mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Refactor CPU compile preamble (#708)
* refactor cpu preamble * fix include order * fix some issues' * fixes for linux * try to fix includes * add back warning suppression * more linux fixes
This commit is contained in:
@@ -11,59 +11,6 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
struct AbsOp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::abs(x);
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct SignOp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct RoundOp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::rint(x);
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {std::rint(x.real()), std::rint(x.imag())};
|
||||
}
|
||||
};
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(in);
|
||||
|
||||
Reference in New Issue
Block a user