mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
@@ -1678,7 +1678,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
c_mlx = mxop(a_mlx, axis=axis)
|
||||
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||
|
||||
a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
|
||||
for op in ["cumsum", "cumprod", "cummax", "cummin"]:
|
||||
mxop = getattr(mx, op)
|
||||
c1 = mxop(a_mlx, axis=2)
|
||||
c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False)
|
||||
self.assertTrue(mx.array_equal(c1[:, :, :-1], c2[:, :, 1:]))
|
||||
@@ -1719,6 +1721,18 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = mx.cumsum(a_t, axis=-1)
|
||||
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
|
||||
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
|
||||
sizes = [1023, 1024, 1025, 2047, 2048, 2049]
|
||||
for s in sizes:
|
||||
a = mx.ones((s,), mx.int32)
|
||||
out = mx.cumsum(a)
|
||||
expected = mx.arange(1, s + 1, dtype=mx.int32)
|
||||
self.assertTrue(mx.array_equal(expected, out))
|
||||
|
||||
# non-contiguous scan
|
||||
a = mx.ones((s, 2), mx.int32)
|
||||
out = mx.cumsum(a, axis=0)
|
||||
expected = mx.repeat(expected[:, None], 2, axis=1)
|
||||
self.assertTrue(mx.array_equal(expected, out))
|
||||
|
||||
def test_squeeze_expand(self):
|
||||
a = mx.zeros((2, 1, 2, 1))
|
||||
|
Reference in New Issue
Block a user