[CUDA] Implement Scan kernel (#2347)

* Contiguous scan

* Strided scan

* Enable tests

* Fix failing logaddexp test

* Use cexpf in Metal
This commit is contained in:
Cheng
2025-07-11 08:54:12 +09:00
committed by GitHub
parent b6eec20260
commit 8347575ba1
13 changed files with 815 additions and 64 deletions

View File

@@ -1350,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") {
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
CHECK(allclose(exp(x), expected).item<bool>());
// Complex of -inf
constexpr float inf = std::numeric_limits<float>::infinity();
x = array(complex64_t{-inf, -inf});
CHECK_EQ(exp(x).item<complex64_t>(), complex64_t{0, 0});
}
// Test expm1
@@ -1830,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") {
x = array(-inf);
y = array(inf);
CHECK_EQ(logaddexp(x, y).item<float>(), inf);
x = array(complex64_t{1, 1});
y = array(complex64_t{-inf, -inf});
CHECK_EQ(logaddexp(x, y).item<complex64_t>(), complex64_t{1, 1});
}
TEST_CASE("test broadcast") {