mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Fix out-of-bounds default value in logsumexp/softmax (#2213)
This commit is contained in:
parent
7774b87cbd
commit
79071bfba4
@ -103,8 +103,8 @@ template <typename T, typename AccT = float, int N_READS = 4>
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
|
vals[i] =
|
||||||
: Limits<AccT>::finite_min;
|
(offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
prevmax = maxval;
|
prevmax = maxval;
|
||||||
|
@ -128,8 +128,8 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
|
vals[i] =
|
||||||
: Limits<AccT>::finite_min;
|
(offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
prevmax = maxval;
|
prevmax = maxval;
|
||||||
|
@ -1036,6 +1036,9 @@ TEST_CASE("test reduction ops") {
|
|||||||
x = array({-inf, -inf});
|
x = array({-inf, -inf});
|
||||||
CHECK_EQ(logsumexp(x).item<float>(), -inf);
|
CHECK_EQ(logsumexp(x).item<float>(), -inf);
|
||||||
|
|
||||||
|
x = repeat(array(-inf), 5000);
|
||||||
|
CHECK_EQ(logsumexp(x).item<float>(), -inf);
|
||||||
|
|
||||||
x = array({0.0f, -inf});
|
x = array({0.0f, -inf});
|
||||||
CHECK_EQ(logsumexp(x).item<float>(), 0.0f);
|
CHECK_EQ(logsumexp(x).item<float>(), 0.0f);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user