mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Custom logsumexp (#2028)
* initial custom logsumexp * more tests * comments + fix
This commit is contained in:
		| @@ -690,15 +690,34 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(np.array_equal(b_npy, b_mlx)) | ||||
|  | ||||
|     def test_logsumexp(self): | ||||
|         def logsumexp(x, axes=None): | ||||
|             maxs = mx.max(x, axis=axes, keepdims=True) | ||||
|             return mx.log(mx.sum(mx.exp(x - maxs), axis=axes, keepdims=True)) + maxs | ||||
|  | ||||
|         x = mx.array( | ||||
|             [ | ||||
|                 [1.0, 2.0], | ||||
|                 [3.0, 4.0], | ||||
|             ] | ||||
|         ) | ||||
|         xnp = np.array(x.tolist(), dtype=np.float32) | ||||
|         expected = np.log(np.sum(np.exp(xnp))) | ||||
|         self.assertTrue(math.isclose(mx.logsumexp(x).item(), expected.item())) | ||||
|         self.assertTrue(math.isclose(mx.logsumexp(x).item(), logsumexp(x).item())) | ||||
|  | ||||
|         x = mx.random.uniform(shape=(1025,)) | ||||
|         self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) | ||||
|  | ||||
|         # Transposed | ||||
|         x = mx.random.uniform(shape=(2, 2, 8)) | ||||
|         x = x.swapaxes(0, 1) | ||||
|         self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) | ||||
|  | ||||
|         # Broadcast | ||||
|         x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8)) | ||||
|         self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) | ||||
|  | ||||
|         # Large | ||||
|         x = mx.random.uniform(shape=(1025,)) | ||||
|         x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8)) | ||||
|         self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) | ||||
|  | ||||
|     def test_mean(self): | ||||
|         x = mx.array( | ||||
| @@ -1643,6 +1662,15 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             x = mx.full((n,), vals=-float("inf")) | ||||
|             self.assertTrue(mx.all(mx.isnan(mx.softmax(x)))) | ||||
|  | ||||
|         # Transposed inputs | ||||
|         a = mx.random.uniform(shape=(32, 32, 32)) | ||||
|         b = mx.softmax(a, axis=-1) | ||||
|         c = mx.softmax(a.swapaxes(0, 1), axis=-1).swapaxes(0, 1) | ||||
|         self.assertEqual((b - c).abs().max().item(), 0.0) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.softmax(mx.array(1.0), axis=-1) | ||||
|  | ||||
|     def test_concatenate(self): | ||||
|         a_npy = np.random.randn(32, 32, 32) | ||||
|         b_npy = np.random.randn(32, 32, 32) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun