diff --git a/python/src/array.cpp b/python/src/array.cpp index 57b867dbc..6dd2f290b 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -23,15 +23,15 @@ enum PyScalarT { pycomplex = 3, }; -template +template py::list to_list(array& a, size_t index, int dim) { py::list pl; auto stride = a.strides()[dim]; for (int i = 0; i < a.shape(dim); ++i) { if (dim == a.ndim() - 1) { - pl.append((a.data()[index])); + pl.append(static_cast(a.data()[index])); } else { - pl.append(to_list(a, index, dim + 1)); + pl.append(to_list(a, index, dim + 1)); } index += stride; } @@ -102,11 +102,11 @@ py::object tolist(array& a) { case int64: return to_list(a, 0, 0); case float16: - return to_list(a, 0, 0); + return to_list(a, 0, 0); case float32: return to_list(a, 0, 0); case bfloat16: - return to_list(a, 0, 0); + return to_list(a, 0, 0); case complex64: return to_list>(a, 0, 0); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 507675d6e..7812642d3 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -431,6 +431,14 @@ class TestArray(mlx_tests.MLXTestCase): x = mx.array(vals) self.assertEqual(x.tolist(), vals) + # Half types + vals = [1.0, 2.0, 3.0, 4.0, 5.0] + x = mx.array(vals, dtype=mx.float16) + self.assertEqual(x.tolist(), vals) + + x = mx.array(vals, dtype=mx.bfloat16) + self.assertEqual(x.tolist(), vals) + def test_array_np_conversion(self): # Shape test a = np.array([])