mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
@@ -1564,6 +1564,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
c2 = mxop(a_mlx, axis=0, inclusive=False, reverse=True)[:-1, :, :]
|
||||
self.assertTrue(mx.array_equal(c1, c2))
|
||||
|
||||
a = mx.random.uniform(shape=(8, 32))
|
||||
mat = mx.tri(32)
|
||||
for t in [mx.float16, mx.bfloat16]:
|
||||
a_t = a.astype(t)
|
||||
mat_t = mat.astype(t)
|
||||
out = mx.cumsum(a_t, axis=-1)
|
||||
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
|
||||
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
|
||||
|
||||
def test_squeeze_expand(self):
|
||||
a = mx.zeros((2, 1, 2, 1))
|
||||
self.assertEqual(mx.squeeze(a).shape, (2, 2))
|
||||
|
Reference in New Issue
Block a user