mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	fix tolist for half types (#702)
This commit is contained in:
		@@ -23,15 +23,15 @@ enum PyScalarT {
 | 
			
		||||
  pycomplex = 3,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
template <typename T, typename U = T>
 | 
			
		||||
py::list to_list(array& a, size_t index, int dim) {
 | 
			
		||||
  py::list pl;
 | 
			
		||||
  auto stride = a.strides()[dim];
 | 
			
		||||
  for (int i = 0; i < a.shape(dim); ++i) {
 | 
			
		||||
    if (dim == a.ndim() - 1) {
 | 
			
		||||
      pl.append((a.data<T>()[index]));
 | 
			
		||||
      pl.append(static_cast<U>(a.data<T>()[index]));
 | 
			
		||||
    } else {
 | 
			
		||||
      pl.append(to_list<T>(a, index, dim + 1));
 | 
			
		||||
      pl.append(to_list<T, U>(a, index, dim + 1));
 | 
			
		||||
    }
 | 
			
		||||
    index += stride;
 | 
			
		||||
  }
 | 
			
		||||
@@ -102,11 +102,11 @@ py::object tolist(array& a) {
 | 
			
		||||
    case int64:
 | 
			
		||||
      return to_list<int64_t>(a, 0, 0);
 | 
			
		||||
    case float16:
 | 
			
		||||
      return to_list<float16_t>(a, 0, 0);
 | 
			
		||||
      return to_list<float16_t, float>(a, 0, 0);
 | 
			
		||||
    case float32:
 | 
			
		||||
      return to_list<float>(a, 0, 0);
 | 
			
		||||
    case bfloat16:
 | 
			
		||||
      return to_list<float16_t>(a, 0, 0);
 | 
			
		||||
      return to_list<bfloat16_t, float>(a, 0, 0);
 | 
			
		||||
    case complex64:
 | 
			
		||||
      return to_list<std::complex<float>>(a, 0, 0);
 | 
			
		||||
  }
 | 
			
		||||
 
 | 
			
		||||
@@ -431,6 +431,14 @@ class TestArray(mlx_tests.MLXTestCase):
 | 
			
		||||
        x = mx.array(vals)
 | 
			
		||||
        self.assertEqual(x.tolist(), vals)
 | 
			
		||||
 | 
			
		||||
        # Half types
 | 
			
		||||
        vals = [1.0, 2.0, 3.0, 4.0, 5.0]
 | 
			
		||||
        x = mx.array(vals, dtype=mx.float16)
 | 
			
		||||
        self.assertEqual(x.tolist(), vals)
 | 
			
		||||
 | 
			
		||||
        x = mx.array(vals, dtype=mx.bfloat16)
 | 
			
		||||
        self.assertEqual(x.tolist(), vals)
 | 
			
		||||
 | 
			
		||||
    def test_array_np_conversion(self):
 | 
			
		||||
        # Shape test
 | 
			
		||||
        a = np.array([])
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user