mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
[CUDA] Implement Scan kernel (#2347)
* Contiguous scan * Strided scan * Enable tests * Fix failing logaddexp test * Use cexpf in Metal
This commit is contained in:
@@ -1350,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") {
|
||||
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
|
||||
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
|
||||
CHECK(allclose(exp(x), expected).item<bool>());
|
||||
|
||||
// Complex of -inf
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
x = array(complex64_t{-inf, -inf});
|
||||
CHECK_EQ(exp(x).item<complex64_t>(), complex64_t{0, 0});
|
||||
}
|
||||
|
||||
// Test expm1
|
||||
@@ -1830,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") {
|
||||
x = array(-inf);
|
||||
y = array(inf);
|
||||
CHECK_EQ(logaddexp(x, y).item<float>(), inf);
|
||||
|
||||
x = array(complex64_t{1, 1});
|
||||
y = array(complex64_t{-inf, -inf});
|
||||
CHECK_EQ(logaddexp(x, y).item<complex64_t>(), complex64_t{1, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test broadcast") {
|
||||
|
Reference in New Issue
Block a user