mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
Refactor reductions and fix scatter atomics for large sizes (#1300)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -43,6 +43,10 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
z_npy = np.sum(y_npy, axis=a) / 1000
|
||||
z_mlx = mx.sum(y_mlx, axis=a) / 1000
|
||||
mx.eval(z_mlx)
|
||||
if not np.allclose(z_npy, np.array(z_mlx), atol=1e-4):
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
self.assertTrue(
|
||||
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
|
||||
)
|
||||
@@ -57,6 +61,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
"uint32",
|
||||
"int64",
|
||||
"uint64",
|
||||
"complex64",
|
||||
]
|
||||
float_dtypes = ["float32"]
|
||||
|
||||
@@ -114,6 +119,15 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
b = getattr(np, op)(data)
|
||||
self.assertEqual(a.item(), b)
|
||||
|
||||
def test_edge_case(self):
|
||||
x = (mx.random.normal((100, 1, 100, 100)) * 128).astype(mx.int32)
|
||||
x = x.transpose(0, 3, 1, 2)
|
||||
|
||||
y = x.sum((0, 2, 3))
|
||||
mx.eval(y)
|
||||
z = np.array(x).sum((0, 2, 3))
|
||||
self.assertTrue(np.all(z == y))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
|
||||
Reference in New Issue
Block a user