Fix view scalar bug segfault (#1603)

* fix view scalar bug

* fix view scalar bug

* one more fix
This commit is contained in:
Awni Hannun
2024-11-19 10:54:05 -08:00
committed by GitHub
parent 5e89aace9b
commit 61d787726a
4 changed files with 7 additions and 3 deletions

View File

@@ -2532,6 +2532,10 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
def test_view(self):
# Check scalar
out = mx.array(1, mx.int8).view(mx.uint8).item()
self.assertEqual(out, 1)
a = mx.random.randint(shape=(4, 2, 4), low=-100, high=100)
a_np = np.array(a)