mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-25 20:58:13 +08:00
@@ -26,12 +26,14 @@ DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
|
||||
@@ -29,6 +29,7 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
@@ -41,6 +42,7 @@ DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
|
||||
@@ -167,6 +167,17 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::ceil(x); });
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
@@ -287,6 +298,17 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::floor(x); });
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@@ -43,6 +43,19 @@ struct ArcTanh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T> T operator()(T x) { return metal::ceil(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
|
||||
|
||||
@@ -83,6 +96,19 @@ struct Exp {
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T> T operator()(T x) { return metal::floor(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T> T operator()(T x) { return metal::precise::log(x); };
|
||||
};
|
||||
@@ -253,9 +279,11 @@ instantiate_unary_float(arcsin, ArcSin)
|
||||
instantiate_unary_float(arcsinh, ArcSinh)
|
||||
instantiate_unary_float(arctan, ArcTan)
|
||||
instantiate_unary_float(arctanh, ArcTanh)
|
||||
instantiate_unary_types(ceil, Ceil)
|
||||
instantiate_unary_float(cos, Cos)
|
||||
instantiate_unary_float(cosh, Cosh)
|
||||
instantiate_unary_float(exp, Exp)
|
||||
instantiate_unary_types(floor, Floor)
|
||||
instantiate_unary_float(log, Log)
|
||||
instantiate_unary_float(log2, Log2)
|
||||
instantiate_unary_float(log10, Log10)
|
||||
|
||||
@@ -450,6 +450,14 @@ void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "min");
|
||||
}
|
||||
|
||||
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "floor");
|
||||
}
|
||||
|
||||
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "ceil");
|
||||
}
|
||||
|
||||
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "mul");
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ NO_GPU(ArgSort)
|
||||
NO_GPU(AsType)
|
||||
NO_GPU(AsStrided)
|
||||
NO_GPU(Broadcast)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU(Concatenate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Copy)
|
||||
@@ -36,6 +37,7 @@ NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Full)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(Greater)
|
||||
|
||||
Reference in New Issue
Block a user