Custom logsumexp (#2028)

* initial custom logsumexp
* more tests
* comments + fix
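The reference helper added in the test below relies on the standard max-subtraction form of logsumexp, which keeps exp() from overflowing when inputs are large. As a point of comparison, here is a minimal NumPy sketch of the same identity; the name logsumexp_ref is illustrative and not part of the diff:

import numpy as np

def logsumexp_ref(x, axis=None):
    # log(sum(exp(x))) == m + log(sum(exp(x - m))) with m = max(x);
    # subtracting the max first keeps every exp() argument <= 0.
    m = np.max(x, axis=axis, keepdims=True)
    return np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True)) + m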
@@ -690,15 +690,34 @@ class TestOps(mlx_tests.MLXTestCase):
        self.assertTrue(np.array_equal(b_npy, b_mlx))

    def test_logsumexp(self):
        def logsumexp(x, axes=None):
            maxs = mx.max(x, axis=axes, keepdims=True)
            return mx.log(mx.sum(mx.exp(x - maxs), axis=axes, keepdims=True)) + maxs

        x = mx.array(
            [
                [1.0, 2.0],
                [3.0, 4.0],
            ]
        )
        xnp = np.array(x.tolist(), dtype=np.float32)
        expected = np.log(np.sum(np.exp(xnp)))
        self.assertTrue(math.isclose(mx.logsumexp(x).item(), expected.item()))
        self.assertTrue(math.isclose(mx.logsumexp(x).item(), logsumexp(x).item()))

        x = mx.random.uniform(shape=(1025,))
        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))

        # Transposed
        x = mx.random.uniform(shape=(2, 2, 8))
        x = x.swapaxes(0, 1)
        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))

        # Broadcast
        x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))

        # Large
        x = mx.random.uniform(shape=(1025,))
        x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))

    def test_mean(self):
        x = mx.array(
@@ -1643,6 +1662,15 @@ class TestOps(mlx_tests.MLXTestCase):
        x = mx.full((n,), vals=-float("inf"))
        self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))

        # Transposed inputs
        a = mx.random.uniform(shape=(32, 32, 32))
        b = mx.softmax(a, axis=-1)
        c = mx.softmax(a.swapaxes(0, 1), axis=-1).swapaxes(0, 1)
        self.assertEqual((b - c).abs().max().item(), 0.0)

        with self.assertRaises(ValueError):
            mx.softmax(mx.array(1.0), axis=-1)

    def test_concatenate(self):
        a_npy = np.random.randn(32, 32, 32)
        b_npy = np.random.randn(32, 32, 32)
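As a side note, the two ops touched by this diff are linked by the identity softmax(x) = exp(x - logsumexp(x)), which gives a quick way to cross-check one against the other. A small sketch using the public MLX API (not part of the commit; shapes are arbitrary):

import mlx.core as mx

x = mx.random.uniform(shape=(4, 8))
# softmax along the last axis equals exp(x - logsumexp(x)) along that axis
lse = mx.logsumexp(x, axis=-1, keepdims=True)
assert mx.allclose(mx.softmax(x, axis=-1), mx.exp(x - lse)).item()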