fix broadcast bug in bitwise ops (#1157)

This commit is contained in:
Awni Hannun
2024-05-24 11:44:40 -07:00
committed by GitHub
parent 9f9cb7a2ef
commit a87ef5bfc1
3 changed files with 12 additions and 4 deletions

View File

@@ -2291,6 +2291,13 @@ class TestOps(mlx_tests.MLXTestCase):
out_np = getattr(np, op)(a_np, b_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
# Check broadcasting
a = mx.ones((3, 1, 5), dtype=mx.bool_)
b = mx.zeros((1, 2, 5), dtype=mx.bool_)
c = a | b
self.assertEqual(c.shape, (3, 2, 5))
self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
def test_conjugate(self):
shape = (3, 5, 7)
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)