mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
consistently handle all -inf in softmax (#1470)
This commit is contained in:
@@ -1567,6 +1567,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = mx.softmax(a, axis=-1, precise=True)
|
||||
self.assertTrue(mx.allclose(out_expect, out))
|
||||
|
||||
# All Infs give NaNs
|
||||
for n in [127, 128, 129]:
|
||||
x = mx.full((n,), vals=-float("inf"))
|
||||
self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))
|
||||
|
||||
def test_concatenate(self):
|
||||
a_npy = np.random.randn(32, 32, 32)
|
||||
b_npy = np.random.randn(32, 32, 32)
|
||||
|
Reference in New Issue
Block a user