mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
reduce binary size (#1952)
This commit is contained in:
@@ -14,88 +14,57 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto op = detail::Abs{};
|
||||
switch (out.dtype()) {
|
||||
case int8:
|
||||
unary_op<int8_t>(in, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(in, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(in, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(in, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(in, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(in, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(in, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Abs] Called on unsigned type");
|
||||
}
|
||||
unary_signed(in, out, detail::Abs(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
unary_fp(in, out, detail::ArcCos(), stream());
|
||||
}
|
||||
|
||||
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
unary_fp(in, out, detail::ArcCosh(), stream());
|
||||
}
|
||||
|
||||
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
unary_fp(in, out, detail::ArcSin(), stream());
|
||||
}
|
||||
|
||||
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
unary_fp(in, out, detail::ArcSinh(), stream());
|
||||
}
|
||||
|
||||
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
unary_fp(in, out, detail::ArcTan(), stream());
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
unary_fp(in, out, detail::ArcTanh(), stream());
|
||||
}
|
||||
|
||||
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_int(in, out, detail::BitwiseInvert());
|
||||
unary_int(in, out, detail::BitwiseInvert(), stream());
|
||||
}
|
||||
|
||||
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
unary_fp(in, out, detail::Ceil(), stream());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -104,84 +73,50 @@ void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
|
||||
unary_complex(inputs[0], out, detail::Conjugate(), stream());
|
||||
}
|
||||
|
||||
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cos());
|
||||
unary_fp(in, out, detail::Cos(), stream());
|
||||
}
|
||||
|
||||
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
unary_fp(in, out, detail::Cosh(), stream());
|
||||
}
|
||||
|
||||
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::Erf());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(in, out, detail::Erf());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf] Error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
unary_real_fp(in, out, detail::Erf(), stream());
|
||||
}
|
||||
|
||||
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf_inv] Inverse error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
unary_real_fp(in, out, detail::ErfInv(), stream());
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Exp());
|
||||
unary_fp(in, out, detail::Exp(), stream());
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
unary_fp(in, out, detail::Expm1(), stream());
|
||||
}
|
||||
|
||||
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
unary_fp(in, out, detail::Floor(), stream());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -189,7 +124,7 @@ void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
|
||||
unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
|
||||
}
|
||||
|
||||
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -197,13 +132,13 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
const auto& in = inputs[0];
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
unary_fp(in, out, detail::Log(), stream());
|
||||
break;
|
||||
case Base::two:
|
||||
unary_fp(in, out, detail::Log2());
|
||||
unary_fp(in, out, detail::Log2(), stream());
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_fp(in, out, detail::Log10());
|
||||
unary_fp(in, out, detail::Log10(), stream());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -211,30 +146,30 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
unary_fp(in, out, detail::Log1p(), stream());
|
||||
}
|
||||
|
||||
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::LogicalNot());
|
||||
unary(in, out, detail::LogicalNot(), stream());
|
||||
}
|
||||
|
||||
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Negative());
|
||||
unary(in, out, detail::Negative(), stream());
|
||||
}
|
||||
|
||||
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
unary_complex_to_float(inputs[0], out, detail::Real(), stream());
|
||||
}
|
||||
|
||||
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
unary_fp(in, out, detail::Round(), stream());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -244,7 +179,7 @@ void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
unary_fp(in, out, detail::Sigmoid(), stream());
|
||||
}
|
||||
|
||||
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -253,48 +188,48 @@ void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in.dtype() == bool_) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, detail::Sign());
|
||||
unary(in, out, detail::Sign(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sin());
|
||||
unary_fp(in, out, detail::Sin(), stream());
|
||||
}
|
||||
|
||||
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
unary_fp(in, out, detail::Sinh(), stream());
|
||||
}
|
||||
|
||||
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Square());
|
||||
unary(in, out, detail::Square(), stream());
|
||||
}
|
||||
|
||||
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (recip_) {
|
||||
unary_fp(in, out, detail::Rsqrt());
|
||||
unary_fp(in, out, detail::Rsqrt(), stream());
|
||||
} else {
|
||||
unary_fp(in, out, detail::Sqrt());
|
||||
unary_fp(in, out, detail::Sqrt(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tan());
|
||||
unary_fp(in, out, detail::Tan(), stream());
|
||||
}
|
||||
|
||||
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
unary_fp(in, out, detail::Tanh(), stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user