From 79071bfba4f012517859bcfdd9032123d16cc6b6 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 21 May 2025 23:25:16 +0900 Subject: [PATCH] Fix out-of-bounds default value in logsumexp/softmax (#2213) --- mlx/backend/metal/kernels/logsumexp.h | 4 ++-- mlx/backend/metal/kernels/softmax.h | 4 ++-- tests/ops_tests.cpp | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h index 93744e15d..c746050b3 100644 --- a/mlx/backend/metal/kernels/logsumexp.h +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -103,8 +103,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h index b36b73bd8..6ea4ac732 100644 --- a/mlx/backend/metal/kernels/softmax.h +++ b/mlx/backend/metal/kernels/softmax.h @@ -128,8 +128,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 5e2bae5a0..8833424a6 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1036,6 +1036,9 @@ TEST_CASE("test reduction ops") { x = array({-inf, -inf}); CHECK_EQ(logsumexp(x).item(), -inf); + x = repeat(array(-inf), 5000); + CHECK_EQ(logsumexp(x).item(), -inf); + x = array({0.0f, -inf}); CHECK_EQ(logsumexp(x).item(), 0.0f);