mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch * load + more async CPU * use command encoder API and move more ops to use it * make fence back-end generic + CPU only fence * faster build * fix async eval * fixes + handle temporaries * fix / improve cpu conv * remove unused status, fix siblings * fix extensions * fix * fix no cpu build * format * comments * fix perf regression, remove unecessary abort * fix events, task limit cpu * fix waiting * fix donation / temporaries in normalization
This commit is contained in:
@@ -16,49 +16,49 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
void comparison_op(const array& a, const array& b, array& out) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool>(a, b, out, op);
|
||||
binary_op<bool, bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool>(a, b, out, op);
|
||||
binary_op<uint8_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool>(a, b, out, op);
|
||||
binary_op<uint16_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool>(a, b, out, op);
|
||||
binary_op<uint32_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool>(a, b, out, op);
|
||||
binary_op<uint64_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool>(a, b, out, op);
|
||||
binary_op<int8_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool>(a, b, out, op);
|
||||
binary_op<int16_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool>(a, b, out, op);
|
||||
binary_op<int32_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool>(a, b, out, op);
|
||||
binary_op<int64_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool>(a, b, out, op);
|
||||
binary_op<float16_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool>(a, b, out, op);
|
||||
binary_op<float, bool, Op>(a, b, out);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool>(a, b, out, op);
|
||||
binary_op<double, bool, Op>(a, b, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool>(a, b, out, op);
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool>(a, b, out, op);
|
||||
binary_op<complex64_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -151,47 +151,47 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (equal_nan_) {
|
||||
switch (a.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
|
||||
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool>(a, b, out, detail::NaNEqual());
|
||||
binary_op<float, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool>(a, b, out, detail::NaNEqual());
|
||||
binary_op<double, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
|
||||
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
|
||||
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[NanEqual::eval_cpu] Only for floating point types.");
|
||||
}
|
||||
} else {
|
||||
comparison_op(a, b, out, detail::Equal());
|
||||
comparison_op<detail::Equal>(a, b, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Greater());
|
||||
comparison_op<detail::Greater>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
|
||||
comparison_op<detail::GreaterEqual>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Less());
|
||||
comparison_op<detail::Less>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
|
||||
comparison_op<detail::LessEqual>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -200,16 +200,16 @@ void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
binary_op<float16_t, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
binary_op<float, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, out, detail::LogAddExp());
|
||||
binary_op<double, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
binary_op<bfloat16_t, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
@@ -254,7 +254,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
|
||||
comparison_op<detail::NotEqual>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
Reference in New Issue
Block a user