diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h index 93744e15d..c746050b3 100644 --- a/mlx/backend/metal/kernels/logsumexp.h +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -103,8 +103,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h index b36b73bd8..6ea4ac732 100644 --- a/mlx/backend/metal/kernels/softmax.h +++ b/mlx/backend/metal/kernels/softmax.h @@ -128,8 +128,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 5e2bae5a0..8833424a6 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1036,6 +1036,9 @@ TEST_CASE("test reduction ops") { x = array({-inf, -inf}); CHECK_EQ(logsumexp(x).item(), -inf); + x = repeat(array(-inf), 5000); + CHECK_EQ(logsumexp(x).item(), -inf); + x = array({0.0f, -inf}); CHECK_EQ(logsumexp(x).item(), 0.0f);