mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 10:02:12 +08:00
fix tolist for half types (#702)
This commit is contained in:
parent
f883fcede0
commit
d0fda82595
@ -23,15 +23,15 @@ enum PyScalarT {
|
||||
pycomplex = 3,
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename U = T>
|
||||
py::list to_list(array& a, size_t index, int dim) {
|
||||
py::list pl;
|
||||
auto stride = a.strides()[dim];
|
||||
for (int i = 0; i < a.shape(dim); ++i) {
|
||||
if (dim == a.ndim() - 1) {
|
||||
pl.append((a.data<T>()[index]));
|
||||
pl.append(static_cast<U>(a.data<T>()[index]));
|
||||
} else {
|
||||
pl.append(to_list<T>(a, index, dim + 1));
|
||||
pl.append(to_list<T, U>(a, index, dim + 1));
|
||||
}
|
||||
index += stride;
|
||||
}
|
||||
@ -102,11 +102,11 @@ py::object tolist(array& a) {
|
||||
case int64:
|
||||
return to_list<int64_t>(a, 0, 0);
|
||||
case float16:
|
||||
return to_list<float16_t>(a, 0, 0);
|
||||
return to_list<float16_t, float>(a, 0, 0);
|
||||
case float32:
|
||||
return to_list<float>(a, 0, 0);
|
||||
case bfloat16:
|
||||
return to_list<float16_t>(a, 0, 0);
|
||||
return to_list<bfloat16_t, float>(a, 0, 0);
|
||||
case complex64:
|
||||
return to_list<std::complex<float>>(a, 0, 0);
|
||||
}
|
||||
|
@ -431,6 +431,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(vals)
|
||||
self.assertEqual(x.tolist(), vals)
|
||||
|
||||
# Half types
|
||||
vals = [1.0, 2.0, 3.0, 4.0, 5.0]
|
||||
x = mx.array(vals, dtype=mx.float16)
|
||||
self.assertEqual(x.tolist(), vals)
|
||||
|
||||
x = mx.array(vals, dtype=mx.bfloat16)
|
||||
self.assertEqual(x.tolist(), vals)
|
||||
|
||||
def test_array_np_conversion(self):
|
||||
# Shape test
|
||||
a = np.array([])
|
||||
|
Loading…
Reference in New Issue
Block a user