mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	LogCumSumExp (#2069)
This commit is contained in:
		| @@ -1508,6 +1508,7 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|             ("prod", 1), | ||||
|             ("min", 1), | ||||
|             ("max", 1), | ||||
|             ("logcumsumexp", 1), | ||||
|             ("logsumexp", 1), | ||||
|             ("mean", 1), | ||||
|             ("var", 1), | ||||
|   | ||||
| @@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         y = mx.as_strided(x, (x.size,), (-1,), x.size - 1) | ||||
|         self.assertTrue(mx.array_equal(y, x[::-1])) | ||||
|  | ||||
|     def test_logcumsumexp(self): | ||||
|         npop = np.logaddexp.accumulate | ||||
|         mxop = mx.logcumsumexp | ||||
|  | ||||
|         a_npy = np.random.randn(32, 32, 32).astype(np.float32) | ||||
|         a_mlx = mx.array(a_npy) | ||||
|  | ||||
|         for axis in (0, 1, 2): | ||||
|             c_npy = npop(a_npy, axis=axis) | ||||
|             c_mlx = mxop(a_mlx, axis=axis) | ||||
|             self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) | ||||
|  | ||||
|         edge_cases_npy = [ | ||||
|             np.float32([-float("inf")] * 8), | ||||
|             np.float32([-float("inf"), 0, -float("inf")]), | ||||
|             np.float32([-float("inf"), float("inf"), -float("inf")]), | ||||
|         ] | ||||
|         edge_cases_mlx = [mx.array(a) for a in edge_cases_npy] | ||||
|  | ||||
|         for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx): | ||||
|             c_npy = npop(a_npy, axis=0) | ||||
|             c_mlx = mxop(a_mlx, axis=0) | ||||
|             self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) | ||||
|  | ||||
|     def test_scans(self): | ||||
|         a_npy = np.random.randn(32, 32, 32).astype(np.float32) | ||||
|         a_mlx = mx.array(a_npy) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Yury Popov
					Yury Popov