Working 64-bit scans (#1506)

This commit is contained in:
Angelos Katharopoulos
2024-10-24 11:05:46 -07:00
committed by GitHub
parent 32972a5924
commit c9b41d460f
9 changed files with 309 additions and 137 deletions

View File

@@ -1758,6 +1758,18 @@ class TestOps(mlx_tests.MLXTestCase):
c_mlx = mxop(a_mlx, axis=axis)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
for dt in [mx.int32, mx.int64]:
mxx = a_mlx.astype(dt)
npx = np.array(mxx)
for op in ["cumsum", "cumprod"]:
npop = getattr(np, op)
mxop = getattr(mx, op)
for axis in (None, 0, 1, 2):
c_npy = npop(npx, axis=axis, dtype=npx.dtype)
c_mlx = mxop(mxx, axis=axis)
self.assertTrue(np.array_equal(c_npy, c_mlx))
a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
for op in ["cumsum", "cumprod", "cummax", "cummin"]:
mxop = getattr(mx, op)