Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-25 09:51:17 +08:00
consistently handle all -inf in softmax (#1470)
commit 1fa0d20a30 (parent 3274c6a087)
@@ -33,8 +33,8 @@ namespace {
  * Note: The implementation below is a general fast exp. There could be faster
  * implementations for numbers strictly < 0.
  */
-inline simd_float16 simd_fast_exp(simd_float16 x) {
-  x *= 1.442695; // multiply with log_2(e)
+inline simd_float16 simd_fast_exp(simd_float16 x_init) {
+  auto x = x_init * 1.442695; // multiply with log_2(e)
   simd_float16 ipart, fpart;
   simd_int16 epart;
   x = simd_clamp(x, -80, 80);
@@ -53,7 +53,9 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
   // bitshifting
   epart = (simd_int(ipart) + 127) << 23;
 
-  return (*(simd_float16*)&epart) * x;
+  // Avoid supressing NaNs
+  simd_int16 eq = (x_init == x_init);
+  return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
 }
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
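For context, the select added above keeps NaN inputs intact: (x_init == x_init) is false only for NaN lanes, so simd_bitselect returns the original x_init there instead of the clamped fast-exp result. Below is a minimal NumPy analogue of that control flow; the function name and the use of np.exp2 are stand-ins, not code from the repo.

import numpy as np

def fast_exp_nan_preserving(x_init):
    # Sketch of the SIMD routine above: base-2 exponent, clamped range, then a
    # self-equality mask that routes NaN inputs through unchanged.
    x_init = np.asarray(x_init, dtype=np.float32)
    x = x_init * np.float32(1.442695)   # multiply with log_2(e)
    x = np.clip(x, -80, 80)
    result = np.exp2(x)                 # stand-in for the polynomial + bit-shift exp
    # (x_init == x_init) is False only for NaN, mirroring the simd_bitselect mask
    return np.where(x_init == x_init, result, x_init)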
@@ -32,12 +32,12 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
-      ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
-                                                : Limits<AccT>::finite_min;
+      ld[i] =
+          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
     }
   }
   if (simd_group_id == 0) {
-    local_max[simd_lane_id] = Limits<AccT>::finite_min;
+    local_max[simd_lane_id] = Limits<AccT>::min;
     local_normalizer[simd_lane_id] = 0;
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
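The kernel change above swaps Limits<AccT>::finite_min for Limits<AccT>::min when padding out-of-range lanes and initializing the running max. Assuming Limits<AccT>::min maps to -infinity for floating-point types, a row that is entirely -inf keeps its maximum at -inf, so exp(x - max) yields NaN instead of being masked by a large finite sentinel. A NumPy sketch of the difference; softmax_ref and the explicit pad argument are illustrative only.

import numpy as np

def softmax_ref(x, pad):
    # The pad value stands in for lanes beyond axis_size that still take part
    # in the max/sum reductions inside the kernel.
    x = np.concatenate([x, [pad]])
    e = np.exp(x - np.max(x))   # -inf - (-inf) -> NaN when the max itself is -inf
    return e / np.sum(e)

row = np.full(3, -np.inf, dtype=np.float32)
print(softmax_ref(row, np.finfo(np.float32).min))  # finite sentinel hides the all -inf row
print(softmax_ref(row, -np.inf))                   # -inf padding lets the NaN surface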
@@ -1567,6 +1567,11 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.softmax(a, axis=-1, precise=True)
         self.assertTrue(mx.allclose(out_expect, out))
 
+        # All Infs give NaNs
+        for n in [127, 128, 129]:
+            x = mx.full((n,), vals=-float("inf"))
+            self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))
+
     def test_concatenate(self):
         a_npy = np.random.randn(32, 32, 32)
         b_npy = np.random.randn(32, 32, 32)
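The new assertions pin the behaviour to the usual reference convention: a row that is entirely -inf has max -inf, so every x - max term is NaN and so is the softmax. The sizes 127, 128, and 129 sit around a power-of-two boundary, presumably to exercise more than one kernel path. A small NumPy check of the convention follows; softmax_np is just a reference helper, not part of the test suite.

import numpy as np

def softmax_np(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

all_inf = np.full(128, -np.inf, dtype=np.float32)
mixed = np.array([-np.inf, 0.0, 1.0], dtype=np.float32)
print(np.isnan(softmax_np(all_inf)).all())  # True: all -inf gives all NaN
print(softmax_np(mixed))                    # roughly [0., 0.269, 0.731]: mixed rows unaffected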