mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Fix boolean all reduce bug (#1355)
This commit is contained in:

committed by
GitHub

parent
64bec4fad7
commit
8081df79be
@@ -124,6 +124,13 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
z = np.array(x).sum((0, 2, 3))
|
||||
self.assertTrue(np.all(z == y))
|
||||
|
||||
def test_sum_bool(self):
|
||||
x = np.random.uniform(0, 1, size=(10, 10, 10)) > 0.5
|
||||
y = mx.array(x)
|
||||
npsum = x.sum().item()
|
||||
mxsum = y.sum().item()
|
||||
self.assertEqual(npsum, mxsum)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
|
Reference in New Issue
Block a user