mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 13:54:44 +08:00
Working 64-bit scans (#1506)
This commit is contained in:

committed by
GitHub

parent
32972a5924
commit
c9b41d460f
@@ -1758,6 +1758,18 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
c_mlx = mxop(a_mlx, axis=axis)
|
||||
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||
|
||||
a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
|
||||
for dt in [mx.int32, mx.int64]:
|
||||
mxx = a_mlx.astype(dt)
|
||||
npx = np.array(mxx)
|
||||
for op in ["cumsum", "cumprod"]:
|
||||
npop = getattr(np, op)
|
||||
mxop = getattr(mx, op)
|
||||
for axis in (None, 0, 1, 2):
|
||||
c_npy = npop(npx, axis=axis, dtype=npx.dtype)
|
||||
c_mlx = mxop(mxx, axis=axis)
|
||||
self.assertTrue(np.array_equal(c_npy, c_mlx))
|
||||
|
||||
a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
|
||||
for op in ["cumsum", "cumprod", "cummax", "cummin"]:
|
||||
mxop = getattr(mx, op)
|
||||
|
Reference in New Issue
Block a user