mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 12:09:43 +08:00
split logsumexp
This commit is contained in:
@@ -760,6 +760,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
# Even larger
|
||||
x = mx.random.uniform(shape=(4 * 4096 + 3,))
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
def test_mean(self):
|
||||
x = mx.array(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user