Fix logsumexp edge case (#740)

* fix logsumexp

* fix inf constant

* also fix power grad

* fix ternary dispatch
This commit is contained in:
Awni Hannun
2024-02-25 08:39:55 -08:00
committed by GitHub
parent ac02cf33bd
commit e6418781ab
12 changed files with 112 additions and 64 deletions

View File

@@ -791,13 +791,13 @@ TEST_CASE("test reduction ops") {
constexpr float inf = std::numeric_limits<float>::infinity();
x = array({-inf, -inf});
WARN_EQ(logsumexp(x).item<float>(), -inf);
CHECK_EQ(logsumexp(x).item<float>(), -inf);
x = array({0.0f, -inf});
CHECK_EQ(logsumexp(x).item<float>(), 0.0f);
x = array({0.0f, inf});
WARN_EQ(logsumexp(x).item<float>(), inf);
CHECK_EQ(logsumexp(x).item<float>(), inf);
x = reshape(arange(6, float32), {2, 3});
@@ -2819,4 +2819,4 @@ TEST_CASE("test atleast_3d") {
out = atleast_3d(x);
CHECK_EQ(out.ndim(), 3);
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
}
}