consistently handle all -inf in softmax (#1470)

This commit is contained in:
Awni Hannun 2024-10-08 09:54:02 -07:00 committed by GitHub
parent 3274c6a087
commit 1fa0d20a30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 6 deletions

View File

@ -33,8 +33,8 @@ namespace {
* Note: The implementation below is a general fast exp. There could be faster * Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0. * implementations for numbers strictly < 0.
*/ */
inline simd_float16 simd_fast_exp(simd_float16 x) { inline simd_float16 simd_fast_exp(simd_float16 x_init) {
x *= 1.442695; // multiply with log_2(e) auto x = x_init * 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart; simd_float16 ipart, fpart;
simd_int16 epart; simd_int16 epart;
x = simd_clamp(x, -80, 80); x = simd_clamp(x, -80, 80);
@ -53,7 +53,9 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
// bitshifting // bitshifting
epart = (simd_int(ipart) + 127) << 23; epart = (simd_int(ipart) + 127) << 23;
return (*(simd_float16*)&epart) * x; // Avoid supressing NaNs
simd_int16 eq = (x_init == x_init);
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
} }
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

View File

@ -32,12 +32,12 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
} }
} else { } else {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i]) ld[i] =
: Limits<AccT>::finite_min; ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
} }
} }
if (simd_group_id == 0) { if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::finite_min; local_max[simd_lane_id] = Limits<AccT>::min;
local_normalizer[simd_lane_id] = 0; local_normalizer[simd_lane_id] = 0;
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);

View File

@ -1567,6 +1567,11 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.softmax(a, axis=-1, precise=True) out = mx.softmax(a, axis=-1, precise=True)
self.assertTrue(mx.allclose(out_expect, out)) self.assertTrue(mx.allclose(out_expect, out))
# All Infs give NaNs
for n in [127, 128, 129]:
x = mx.full((n,), vals=-float("inf"))
self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))
def test_concatenate(self): def test_concatenate(self):
a_npy = np.random.randn(32, 32, 32) a_npy = np.random.randn(32, 32, 32)
b_npy = np.random.randn(32, 32, 32) b_npy = np.random.randn(32, 32, 32)