Custom logsumexp (#2028)

* initial custom logsumexp

* more tests

* comments + fix
This commit is contained in:
Awni Hannun
2025-03-31 07:36:55 -07:00
committed by GitHub
parent ec2854b13a
commit de5f38fd48
27 changed files with 590 additions and 255 deletions

View File

@@ -690,15 +690,34 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(b_npy, b_mlx))
def test_logsumexp(self):
def logsumexp(x, axes=None):
maxs = mx.max(x, axis=axes, keepdims=True)
return mx.log(mx.sum(mx.exp(x - maxs), axis=axes, keepdims=True)) + maxs
x = mx.array(
[
[1.0, 2.0],
[3.0, 4.0],
]
)
xnp = np.array(x.tolist(), dtype=np.float32)
expected = np.log(np.sum(np.exp(xnp)))
self.assertTrue(math.isclose(mx.logsumexp(x).item(), expected.item()))
self.assertTrue(math.isclose(mx.logsumexp(x).item(), logsumexp(x).item()))
x = mx.random.uniform(shape=(1025,))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
# Transposed
x = mx.random.uniform(shape=(2, 2, 8))
x = x.swapaxes(0, 1)
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
# Broadcast
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
# Large
x = mx.random.uniform(shape=(1025,))
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
def test_mean(self):
x = mx.array(
@@ -1643,6 +1662,15 @@ class TestOps(mlx_tests.MLXTestCase):
x = mx.full((n,), vals=-float("inf"))
self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))
# Transposed inputs
a = mx.random.uniform(shape=(32, 32, 32))
b = mx.softmax(a, axis=-1)
c = mx.softmax(a.swapaxes(0, 1), axis=-1).swapaxes(0, 1)
self.assertEqual((b - c).abs().max().item(), 0.0)
with self.assertRaises(ValueError):
mx.softmax(mx.array(1.0), axis=-1)
def test_concatenate(self):
a_npy = np.random.randn(32, 32, 32)
b_npy = np.random.randn(32, 32, 32)