Faster general unary op (#2472)

* faster general unary op

* faster general ops + reorg

* fix + comment

* binary two

* copy general
This commit is contained in:
Awni Hannun
2025-08-15 15:04:12 -07:00
committed by GitHub
parent dfb5022eab
commit 6441c21a94
62 changed files with 1215 additions and 203 deletions

View File

@@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
// GPU evaluation entry point for the Equal primitive.
//
// Forwards `inputs` and `out` to binary_op_gpu on the stream owned by the
// output's primitive. The element-wise functor is chosen by `equal_nan_`:
// cu::NaNEqual when the flag is set, cu::Equal otherwise.
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Named NVTX range that stays alive for the whole call.
  nvtx3::scoped_range profile_range("Equal::eval_gpu");
  auto& stream = out.primitive().stream();

  // Common case first: plain equality.
  if (!equal_nan_) {
    binary_op_gpu<cu::Equal>(inputs, out, name(), stream);
    return;
  }

  // equal_nan_ requested: use the NaN-aware comparison functor.
  binary_op_gpu<cu::NaNEqual>(inputs, out, name(), stream);
}
} // namespace mlx::core