Add bitwise ops (#1037)

* bitwise ops

* fix tests
This commit is contained in:
Awni Hannun
2024-04-26 22:03:42 -07:00
committed by GitHub
parent 67d1894759
commit 86f495985b
17 changed files with 568 additions and 58 deletions

View File

@@ -2177,6 +2177,38 @@ class TestOps(mlx_tests.MLXTestCase):
f"mx and np don't aggree on {a}, {b}",
)
def test_bitwise_ops(self):
types = [
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
mx.int8,
mx.int16,
mx.int32,
mx.int64,
]
a = mx.random.randint(0, 4096, (1000,))
b = mx.random.randint(0, 4096, (1000,))
for op in ["bitwise_and", "bitwise_or", "bitwise_xor"]:
for t in types:
a_mlx = a.astype(t)
b_mlx = b.astype(t)
a_np = np.array(a_mlx)
b_np = np.array(b_mlx)
out_mlx = getattr(mx, op)(a_mlx, b_mlx)
out_np = getattr(np, op)(a_np, b_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
for op in ["left_shift", "right_shift"]:
for t in types:
a_mlx = a.astype(t)
b_mlx = mx.random.randint(0, t.size, (1000,)).astype(t)
a_np = np.array(a_mlx)
b_np = np.array(b_mlx)
out_mlx = getattr(mx, op)(a_mlx, b_mlx)
out_np = getattr(np, op)(a_np, b_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
if __name__ == "__main__":
unittest.main()