Bitwise Inverse (#1862)

* add bitwise inverse

* add vmap + fix nojit

* inverse -> invert

* add to compile + remove unused
This commit is contained in:
Alex Barron
2025-02-13 08:44:14 -08:00
committed by GitHub
parent e425dc00c0
commit 5cd97f7ffe
19 changed files with 147 additions and 8 deletions

View File

@@ -2573,6 +2573,18 @@ class TestOps(mlx_tests.MLXTestCase):
out_np = getattr(np, op)(a_np, b_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
for t in types:
a_mlx = a.astype(t)
a_np = np.array(a_mlx)
out_mlx = ~a_mlx
out_np = ~a_np
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
out_mlx = mx.bitwise_invert(a_mlx)
out_np = mx.bitwise_invert(a_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
# Check broadcasting
a = mx.ones((3, 1, 5), dtype=mx.bool_)
b = mx.zeros((1, 2, 5), dtype=mx.bool_)