split logsumexp

This commit is contained in:
Awni Hannun
2025-05-06 17:10:14 -07:00
parent 5a1a5d5ed1
commit 7c99acb799
3 changed files with 42 additions and 10 deletions

View File

@@ -760,6 +760,10 @@ class TestOps(mlx_tests.MLXTestCase):
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
# Even larger
x = mx.random.uniform(shape=(4 * 4096 + 3,))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
def test_mean(self):
x = mx.array(
[