mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-08 10:14:43 +08:00
@@ -2177,6 +2177,38 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
f"mx and np don't aggree on {a}, {b}",
|
||||
)
|
||||
|
||||
def test_bitwise_ops(self):
|
||||
types = [
|
||||
mx.uint8,
|
||||
mx.uint16,
|
||||
mx.uint32,
|
||||
mx.uint64,
|
||||
mx.int8,
|
||||
mx.int16,
|
||||
mx.int32,
|
||||
mx.int64,
|
||||
]
|
||||
a = mx.random.randint(0, 4096, (1000,))
|
||||
b = mx.random.randint(0, 4096, (1000,))
|
||||
for op in ["bitwise_and", "bitwise_or", "bitwise_xor"]:
|
||||
for t in types:
|
||||
a_mlx = a.astype(t)
|
||||
b_mlx = b.astype(t)
|
||||
a_np = np.array(a_mlx)
|
||||
b_np = np.array(b_mlx)
|
||||
out_mlx = getattr(mx, op)(a_mlx, b_mlx)
|
||||
out_np = getattr(np, op)(a_np, b_np)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
for op in ["left_shift", "right_shift"]:
|
||||
for t in types:
|
||||
a_mlx = a.astype(t)
|
||||
b_mlx = mx.random.randint(0, t.size, (1000,)).astype(t)
|
||||
a_np = np.array(a_mlx)
|
||||
b_np = np.array(b_mlx)
|
||||
out_mlx = getattr(mx, op)(a_mlx, b_mlx)
|
||||
out_np = getattr(np, op)(a_np, b_np)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user