fix tolist for half types (#702)

This commit is contained in:
Awni Hannun
2024-02-19 09:44:27 -08:00
committed by GitHub
parent f883fcede0
commit d0fda82595
2 changed files with 13 additions and 5 deletions

View File

@@ -23,15 +23,15 @@ enum PyScalarT {
pycomplex = 3,
};
template <typename T>
template <typename T, typename U = T>
py::list to_list(array& a, size_t index, int dim) {
py::list pl;
auto stride = a.strides()[dim];
for (int i = 0; i < a.shape(dim); ++i) {
if (dim == a.ndim() - 1) {
pl.append((a.data<T>()[index]));
pl.append(static_cast<U>(a.data<T>()[index]));
} else {
pl.append(to_list<T>(a, index, dim + 1));
pl.append(to_list<T, U>(a, index, dim + 1));
}
index += stride;
}
@@ -102,11 +102,11 @@ py::object tolist(array& a) {
case int64:
return to_list<int64_t>(a, 0, 0);
case float16:
return to_list<float16_t>(a, 0, 0);
return to_list<float16_t, float>(a, 0, 0);
case float32:
return to_list<float>(a, 0, 0);
case bfloat16:
return to_list<float16_t>(a, 0, 0);
return to_list<bfloat16_t, float>(a, 0, 0);
case complex64:
return to_list<std::complex<float>>(a, 0, 0);
}