consistently handle all -inf in softmax (#1470)

This commit is contained in:
Awni Hannun
2024-10-08 09:54:02 -07:00
committed by GitHub
parent 3274c6a087
commit 1fa0d20a30
3 changed files with 13 additions and 6 deletions

View File

@@ -1567,6 +1567,11 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.softmax(a, axis=-1, precise=True)
self.assertTrue(mx.allclose(out_expect, out))
# All Infs give NaNs
for n in [127, 128, 129]:
x = mx.full((n,), vals=-float("inf"))
self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))
def test_concatenate(self):
a_npy = np.random.randn(32, 32, 32)
b_npy = np.random.randn(32, 32, 32)