mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:29:35 +08:00
Refactor reductions and fix scatter atomics for large sizes (#1300)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -1763,7 +1763,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
mat_t = mat.astype(t)
|
||||
out = mx.cumsum(a_t, axis=-1)
|
||||
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
|
||||
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
|
||||
self.assertTrue(mx.allclose(out, expected, rtol=0.02, atol=1e-3))
|
||||
sizes = [1023, 1024, 1025, 2047, 2048, 2049]
|
||||
for s in sizes:
|
||||
a = mx.ones((s,), mx.int32)
|
||||
|
Reference in New Issue
Block a user