Fix boolean all reduce bug (#1355)

This commit is contained in:
Angelos Katharopoulos
2024-08-24 10:09:32 -07:00
committed by GitHub
parent 64bec4fad7
commit 8081df79be
4 changed files with 14 additions and 3 deletions

View File

@@ -124,6 +124,13 @@ class TestReduce(mlx_tests.MLXTestCase):
z = np.array(x).sum((0, 2, 3))
self.assertTrue(np.all(z == y))
def test_sum_bool(self):
x = np.random.uniform(0, 1, size=(10, 10, 10)) > 0.5
y = mx.array(x)
npsum = x.sum().item()
mxsum = y.sum().item()
self.assertEqual(npsum, mxsum)
if __name__ == "__main__":
unittest.main(failfast=True)