mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Complex scan (#2094)
This commit is contained in:
		| @@ -10,6 +10,47 @@ import mlx_tests | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def np_wrap_between(x, a): | ||||
|     """Wraps `x` between `[-a, a]`.""" | ||||
|     two_a = 2 * a | ||||
|     zero = 0 | ||||
|     rem = np.remainder(np.add(x, a), two_a) | ||||
|     if isinstance(rem, np.ndarray): | ||||
|         rem = np.select(rem < zero, np.add(rem, two_a), rem) | ||||
|     else: | ||||
|         rem = np.add(rem, two_a) if rem < zero else rem | ||||
|     return np.subtract(rem, a) | ||||
|  | ||||
|  | ||||
| def np_logaddexp(x1: np.ndarray, x2: np.ndarray): | ||||
|     amax = np.maximum(x1, x2) | ||||
|     if np.issubdtype(x1.dtype, np.floating): | ||||
|         delta = np.subtract(x1, x2) | ||||
|         if isinstance(delta, np.ndarray): | ||||
|             return np.select( | ||||
|                 np.isnan(delta), | ||||
|                 np.add(x1, x2), | ||||
|                 np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))), | ||||
|             ) | ||||
|         else: | ||||
|             return ( | ||||
|                 np.add(x1, x2) | ||||
|                 if np.isnan(delta) | ||||
|                 else np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))) | ||||
|             ) | ||||
|     else: | ||||
|         delta = np.subtract(np.add(x1, x2), np.multiply(amax, 2)) | ||||
|         out = np.add(amax, np.log1p(np.exp(delta))) | ||||
|         return np.real(out) + 1j * np_wrap_between(np.imag(out), np.pi) | ||||
|  | ||||
|  | ||||
| def np_cumlogaddexp(x1: np.ndarray, axis: int = -1): | ||||
|     out = x1.copy() | ||||
|     for i in range(1, out.shape[axis]): | ||||
|         out[i] = np_logaddexp(out[i], out[i - 1]) | ||||
|     return out | ||||
|  | ||||
|  | ||||
| class TestOps(mlx_tests.MLXTestCase): | ||||
|     def test_full_ones_zeros(self): | ||||
|         x = mx.full(2, 3.0) | ||||
| @@ -853,6 +894,16 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         # Complex test | ||||
|  | ||||
|         a = mx.array([0, 1, 2, 9.0]) + 1j | ||||
|         b = mx.array([1, 0, 4, 2.5]) + 1j | ||||
|  | ||||
|         result = mx.logaddexp(a, b) | ||||
|         expected = np_logaddexp(np.array(a), np.array(b)) | ||||
|  | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         a = mx.array([float("nan")]) | ||||
|         b = mx.array([0.0]) | ||||
|         self.assertTrue(math.isnan(mx.logaddexp(a, b).item())) | ||||
| @@ -977,6 +1028,13 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         # Complex test | ||||
|         a = mx.array([1, 0.5, 10, 100]) + 1j | ||||
|         result = mx.log1p(a) | ||||
|         expected = np.log1p(a, dtype=np.complex64) | ||||
|  | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|     def test_sigmoid(self): | ||||
|         a = mx.array([0.0, 1.0, -1.0, 5.0, -5.0]) | ||||
|         result = mx.sigmoid(a) | ||||
| @@ -1881,10 +1939,31 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             c_mlx = mxop(a_mlx, axis=0) | ||||
|             self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) | ||||
|  | ||||
|         # Complex tests | ||||
|  | ||||
|         a_npy = np.array([1, 2, 3]).astype(np.float32) + 1j | ||||
|         a_mlx = mx.array(a_npy) | ||||
|         c_npy = np_cumlogaddexp(a_npy, axis=-1) | ||||
|         c_mlx = mxop(a_mlx, axis=-1) | ||||
|         self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) | ||||
|  | ||||
|     def test_scans(self): | ||||
|         a_npy = np.random.randn(32, 32, 32).astype(np.float32) | ||||
|         a_mlx = mx.array(a_npy) | ||||
|  | ||||
|         for op in ["cumsum", "cumprod"]: | ||||
|             npop = getattr(np, op) | ||||
|             mxop = getattr(mx, op) | ||||
|             for axis in (None, 0, 1, 2): | ||||
|                 c_npy = npop(a_npy, axis=axis) | ||||
|                 c_mlx = mxop(a_mlx, axis=axis) | ||||
|                 self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) | ||||
|  | ||||
|         # Complex test | ||||
|  | ||||
|         a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j | ||||
|         a_mlx = mx.array(a_npy) | ||||
|  | ||||
|         for op in ["cumsum", "cumprod"]: | ||||
|             npop = getattr(np, op) | ||||
|             mxop = getattr(mx, op) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Yury Popov
					Yury Popov