Refactor reductions and fix scatter atomics for large sizes (#1300)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-08-22 16:03:31 -07:00
committed by GitHub
parent f9e00efe31
commit 98b6ce3460
18 changed files with 1584 additions and 1235 deletions

View File

@@ -1763,7 +1763,7 @@ class TestOps(mlx_tests.MLXTestCase):
mat_t = mat.astype(t)
out = mx.cumsum(a_t, axis=-1)
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
self.assertTrue(mx.allclose(out, expected, rtol=0.02, atol=1e-3))
sizes = [1023, 1024, 1025, 2047, 2048, 2049]
for s in sizes:
a = mx.ones((s,), mx.int32)

View File

@@ -43,6 +43,10 @@ class TestReduce(mlx_tests.MLXTestCase):
z_npy = np.sum(y_npy, axis=a) / 1000
z_mlx = mx.sum(y_mlx, axis=a) / 1000
mx.eval(z_mlx)
if not np.allclose(z_npy, np.array(z_mlx), atol=1e-4):
import pdb
pdb.set_trace()
self.assertTrue(
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
)
@@ -57,6 +61,7 @@ class TestReduce(mlx_tests.MLXTestCase):
"uint32",
"int64",
"uint64",
"complex64",
]
float_dtypes = ["float32"]
@@ -114,6 +119,15 @@ class TestReduce(mlx_tests.MLXTestCase):
b = getattr(np, op)(data)
self.assertEqual(a.item(), b)
def test_edge_case(self):
    """Compare an int32 sum over axes (0, 2, 3) of a transposed array with NumPy."""
    samples = mx.random.normal((100, 1, 100, 100)) * 128
    permuted = samples.astype(mx.int32).transpose(0, 3, 1, 2)
    reduce_axes = (0, 2, 3)

    # Reduce with MLX and force evaluation before building the reference.
    got = permuted.sum(reduce_axes)
    mx.eval(got)

    # NumPy reference computed from the same transposed integer data.
    want = np.array(permuted).sum(reduce_axes)
    self.assertTrue(np.all(want == got))
# Script entry point: run the tests via unittest, stopping at the first
# failure or error (failfast=True).
if __name__ == "__main__":
unittest.main(failfast=True)