mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
fix broadcast bug in bitwise ops (#1157)
This commit is contained in:
@@ -2291,6 +2291,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out_np = getattr(np, op)(a_np, b_np)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
|
||||
# Check broadcasting
|
||||
a = mx.ones((3, 1, 5), dtype=mx.bool_)
|
||||
b = mx.zeros((1, 2, 5), dtype=mx.bool_)
|
||||
c = a | b
|
||||
self.assertEqual(c.shape, (3, 2, 5))
|
||||
self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
|
||||
|
||||
def test_conjugate(self):
|
||||
shape = (3, 5, 7)
|
||||
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
|
||||
|
Reference in New Issue
Block a user