mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Fix logsumexp edge case (#740)
* fix logsumexp * fix inf constant * also fix power grad * fix ternary dispatch
This commit is contained in:
@@ -791,13 +791,13 @@ TEST_CASE("test reduction ops") {
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
|
||||
x = array({-inf, -inf});
|
||||
WARN_EQ(logsumexp(x).item<float>(), -inf);
|
||||
CHECK_EQ(logsumexp(x).item<float>(), -inf);
|
||||
|
||||
x = array({0.0f, -inf});
|
||||
CHECK_EQ(logsumexp(x).item<float>(), 0.0f);
|
||||
|
||||
x = array({0.0f, inf});
|
||||
WARN_EQ(logsumexp(x).item<float>(), inf);
|
||||
CHECK_EQ(logsumexp(x).item<float>(), inf);
|
||||
|
||||
x = reshape(arange(6, float32), {2, 3});
|
||||
|
||||
@@ -2819,4 +2819,4 @@ TEST_CASE("test atleast_3d") {
|
||||
out = atleast_3d(x);
|
||||
CHECK_EQ(out.ndim(), 3);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user