From fd1d0821d2146e343005c6f17b4282cefde6913e Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Sun, 22 Jun 2025 23:34:32 -0700
Subject: [PATCH] Make sure softmax doesn't change the actual max

---
 mlx/backend/cuda/softmax.cu | 4 ++--
 python/tests/cuda_skip.py   | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu
index fc001ae75..652e6da19 100644
--- a/mlx/backend/cuda/softmax.cu
+++ b/mlx/backend/cuda/softmax.cu
@@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
         make_cast_iterator<AccT>(in),
         vals,
         axis_size,
-        Limits<AccT>::finite_min());
+        Limits<AccT>::min());
     prevmax = maxval;
     maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
     // Online normalizer calculation for softmax:
@@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
   block.sync();
   maxval = warp.thread_rank() < warp.meta_group_size()
       ? local_max[warp.thread_rank()]
-      : Limits<AccT>::finite_min();
+      : Limits<AccT>::min();
   maxval = cg::reduce(warp, maxval, max_op);
   normalizer = normalizer * softmax_exp(prevmax - maxval);
   if (warp.thread_rank() == 0) {
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index bcb95dbb7..89805d017 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -13,7 +13,6 @@ cuda_skip = {
     "TestLayers.test_upsample",
     "TestOps.test_complex_ops",
     "TestOps.test_dynamic_slicing",
-    "TestOps.test_softmax",
     "TestReduce.test_axis_permutation_sums",
     "TestReduce.test_dtypes",
     "TestReduce.test_expand_sums",
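
Note on the change, with a sketch (not part of the commit): the kernel's
"Online normalizer calculation for softmax" comment refers to the single-pass
recurrence of Milakov & Gimelshein, which tracks a running max and rescales
the partial sum of exponentials whenever the max grows. The patch swaps the
out-of-bounds fill value from Limits<AccT>::finite_min() (the most negative
finite value) to Limits<AccT>::min() (negative infinity for floating-point
types), so a padding lane can never win the max reduction: for a row made up
entirely of -inf, a finite fill value would be reported as the row max, i.e.
it would "change the actual max". The standalone C++ sketch below shows the
same recurrence on the host; all names in it are illustrative and are not the
kernel's identifiers.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  std::vector<float> x = {1.0f, 3.0f, 2.0f};

  // Seed the running max with -infinity, the host analogue of
  // Limits<AccT>::min() after this patch. Any real input beats it,
  // so maxval always ends up equal to an actual element.
  float maxval = -std::numeric_limits<float>::infinity();
  float normalizer = 0.0f;

  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    // Rescale the partial sum when the running max grows, then
    // accumulate the current element (single pass over the data).
    normalizer =
        normalizer * std::exp(prevmax - maxval) + std::exp(v - maxval);
  }

  // Second loop only prints softmax(x); the max and normalizer
  // above were computed in one pass.
  for (float v : x) {
    std::printf("%f\n", std::exp(v - maxval) / normalizer);
  }
  return 0;
}

(For a row of all -inf this host sketch still yields NaN, since
exp(-inf - -inf) is NaN; the point of the fill value is narrower, namely
that padding lanes never alter the true max of the actual inputs.)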