Add view op (#1179)

* add view primitive

* nit

* fix view
This commit is contained in:
Awni Hannun
2024-06-04 08:05:27 -07:00
committed by GitHub
parent 81def6ac76
commit ea9090bbc4
14 changed files with 202 additions and 11 deletions

View File

@@ -2333,6 +2333,30 @@ class TestOps(mlx_tests.MLXTestCase):
out_np = a.conj()
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
def test_view(self):
a = mx.random.randint(shape=(4, 2, 4), low=-100, high=100)
a_np = np.array(a)
for t in ["bool_", "int16", "float32", "int64"]:
out = a.view(getattr(mx, t))
expected = a_np.view(getattr(np, t))
self.assertTrue(np.array_equal(out, expected, equal_nan=True))
# Irregular strides
a = mx.random.randint(shape=(2, 4), low=-100, high=100)
a = mx.broadcast_to(a, shape=(4, 2, 4))
for t in ["bool_", "int16", "float32", "int64"]:
out = a.view(getattr(mx, t))
a_out = out.view(mx.int32)
self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
a = mx.random.randint(shape=(4, 4), low=-100, high=100).T
for t in ["bool_", "int16", "float32", "int64"]:
out = a.view(getattr(mx, t))
a_out = out.view(mx.int32)
self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
if __name__ == "__main__":
unittest.main()