* fix scan

* improve grid size

* fix cpu cummax
This commit is contained in:
Awni Hannun
2024-06-05 14:21:58 -07:00
committed by GitHub
parent 0fe6895893
commit 496315fe1d
4 changed files with 23 additions and 8 deletions

View File

@@ -1678,7 +1678,9 @@ class TestOps(mlx_tests.MLXTestCase):
c_mlx = mxop(a_mlx, axis=axis)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
for op in ["cumsum", "cumprod", "cummax", "cummin"]:
mxop = getattr(mx, op)
c1 = mxop(a_mlx, axis=2)
c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False)
self.assertTrue(mx.array_equal(c1[:, :, :-1], c2[:, :, 1:]))
@@ -1719,6 +1721,18 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.cumsum(a_t, axis=-1)
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
sizes = [1023, 1024, 1025, 2047, 2048, 2049]
for s in sizes:
a = mx.ones((s,), mx.int32)
out = mx.cumsum(a)
expected = mx.arange(1, s + 1, dtype=mx.int32)
self.assertTrue(mx.array_equal(expected, out))
# non-contiguous scan
a = mx.ones((s, 2), mx.int32)
out = mx.cumsum(a, axis=0)
expected = mx.repeat(expected[:, None], 2, axis=1)
self.assertTrue(mx.array_equal(expected, out))
def test_squeeze_expand(self):
a = mx.zeros((2, 1, 2, 1))