fix tolist for half types (#702)

This commit is contained in:
Awni Hannun 2024-02-19 09:44:27 -08:00 committed by GitHub
parent f883fcede0
commit d0fda82595
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 5 deletions

View File

@ -23,15 +23,15 @@ enum PyScalarT {
pycomplex = 3,
};
template <typename T>
template <typename T, typename U = T>
py::list to_list(array& a, size_t index, int dim) {
py::list pl;
auto stride = a.strides()[dim];
for (int i = 0; i < a.shape(dim); ++i) {
if (dim == a.ndim() - 1) {
pl.append((a.data<T>()[index]));
pl.append(static_cast<U>(a.data<T>()[index]));
} else {
pl.append(to_list<T>(a, index, dim + 1));
pl.append(to_list<T, U>(a, index, dim + 1));
}
index += stride;
}
@ -102,11 +102,11 @@ py::object tolist(array& a) {
case int64:
return to_list<int64_t>(a, 0, 0);
case float16:
return to_list<float16_t>(a, 0, 0);
return to_list<float16_t, float>(a, 0, 0);
case float32:
return to_list<float>(a, 0, 0);
case bfloat16:
return to_list<float16_t>(a, 0, 0);
return to_list<bfloat16_t, float>(a, 0, 0);
case complex64:
return to_list<std::complex<float>>(a, 0, 0);
}

View File

@ -431,6 +431,14 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(vals)
self.assertEqual(x.tolist(), vals)
# Half types
vals = [1.0, 2.0, 3.0, 4.0, 5.0]
x = mx.array(vals, dtype=mx.float16)
self.assertEqual(x.tolist(), vals)
x = mx.array(vals, dtype=mx.bfloat16)
self.assertEqual(x.tolist(), vals)
def test_array_np_conversion(self):
# Shape test
a = np.array([])