fix tolist for half types (#702)

This commit is contained in:
Awni Hannun
2024-02-19 09:44:27 -08:00
committed by GitHub
parent f883fcede0
commit d0fda82595
2 changed files with 13 additions and 5 deletions

View File

@@ -431,6 +431,14 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(vals)
self.assertEqual(x.tolist(), vals)
# Half types
vals = [1.0, 2.0, 3.0, 4.0, 5.0]
x = mx.array(vals, dtype=mx.float16)
self.assertEqual(x.tolist(), vals)
x = mx.array(vals, dtype=mx.bfloat16)
self.assertEqual(x.tolist(), vals)
def test_array_np_conversion(self):
# Shape test
a = np.array([])