mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
fix tolist for half types (#702)
This commit is contained in:
@@ -23,15 +23,15 @@ enum PyScalarT {
|
||||
pycomplex = 3,
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename U = T>
|
||||
py::list to_list(array& a, size_t index, int dim) {
|
||||
py::list pl;
|
||||
auto stride = a.strides()[dim];
|
||||
for (int i = 0; i < a.shape(dim); ++i) {
|
||||
if (dim == a.ndim() - 1) {
|
||||
pl.append((a.data<T>()[index]));
|
||||
pl.append(static_cast<U>(a.data<T>()[index]));
|
||||
} else {
|
||||
pl.append(to_list<T>(a, index, dim + 1));
|
||||
pl.append(to_list<T, U>(a, index, dim + 1));
|
||||
}
|
||||
index += stride;
|
||||
}
|
||||
@@ -102,11 +102,11 @@ py::object tolist(array& a) {
|
||||
case int64:
|
||||
return to_list<int64_t>(a, 0, 0);
|
||||
case float16:
|
||||
return to_list<float16_t>(a, 0, 0);
|
||||
return to_list<float16_t, float>(a, 0, 0);
|
||||
case float32:
|
||||
return to_list<float>(a, 0, 0);
|
||||
case bfloat16:
|
||||
return to_list<float16_t>(a, 0, 0);
|
||||
return to_list<bfloat16_t, float>(a, 0, 0);
|
||||
case complex64:
|
||||
return to_list<std::complex<float>>(a, 0, 0);
|
||||
}
|
||||
|
Reference in New Issue
Block a user