mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	support python mlx.array creation from list of mlx.array's (#325)
* support python mlx.array creation from list of mlx.array's * include bfloat16 in UT * refactor so that sub array made of all python primitive types gets initialized by fill_vector * address PR comment: arr.shape().size() -> arr.ndim() * address PR comment: get back Dtype constness and let stack to handle type promotions automatically
This commit is contained in:
		@@ -120,7 +120,11 @@ void fill_vector(T list, std::vector<U>& vals) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
 | 
			
		||||
PyScalarT validate_shape(
 | 
			
		||||
    T list,
 | 
			
		||||
    const std::vector<int>& shape,
 | 
			
		||||
    int idx,
 | 
			
		||||
    bool& all_python_primitive_elements) {
 | 
			
		||||
  if (idx >= shape.size()) {
 | 
			
		||||
    throw std::invalid_argument("Initialization encountered extra dimension.");
 | 
			
		||||
  }
 | 
			
		||||
@@ -138,9 +142,17 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
 | 
			
		||||
  for (auto l : list) {
 | 
			
		||||
    PyScalarT t;
 | 
			
		||||
    if (py::isinstance<py::list>(l)) {
 | 
			
		||||
      t = validate_shape(l.template cast<py::list>(), shape, idx + 1);
 | 
			
		||||
      t = validate_shape(
 | 
			
		||||
          l.template cast<py::list>(),
 | 
			
		||||
          shape,
 | 
			
		||||
          idx + 1,
 | 
			
		||||
          all_python_primitive_elements);
 | 
			
		||||
    } else if (py::isinstance<py::tuple>(*list.begin())) {
 | 
			
		||||
      t = validate_shape(l.template cast<py::tuple>(), shape, idx + 1);
 | 
			
		||||
      t = validate_shape(
 | 
			
		||||
          l.template cast<py::tuple>(),
 | 
			
		||||
          shape,
 | 
			
		||||
          idx + 1,
 | 
			
		||||
          all_python_primitive_elements);
 | 
			
		||||
    } else if (py::isinstance<py::bool_>(l)) {
 | 
			
		||||
      t = pybool;
 | 
			
		||||
    } else if (py::isinstance<py::int_>(l)) {
 | 
			
		||||
@@ -149,6 +161,19 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
 | 
			
		||||
      t = pyfloat;
 | 
			
		||||
    } else if (PyComplex_Check(l.ptr())) {
 | 
			
		||||
      t = pycomplex;
 | 
			
		||||
    } else if (py::isinstance<array>(l)) {
 | 
			
		||||
      all_python_primitive_elements = false;
 | 
			
		||||
      auto arr = py::cast<array>(l);
 | 
			
		||||
      if (arr.ndim() + idx + 1 == shape.size() &&
 | 
			
		||||
          std::equal(
 | 
			
		||||
              arr.shape().cbegin(),
 | 
			
		||||
              arr.shape().cend(),
 | 
			
		||||
              shape.cbegin() + idx + 1)) {
 | 
			
		||||
        t = pybool;
 | 
			
		||||
      } else {
 | 
			
		||||
        throw std::invalid_argument(
 | 
			
		||||
            "Initialization encountered non-uniform length.");
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
      std::ostringstream msg;
 | 
			
		||||
      msg << "Invalid type in array initialization" << l.get_type() << ".";
 | 
			
		||||
@@ -168,6 +193,64 @@ void get_shape(T list, std::vector<int>& shape) {
 | 
			
		||||
      return get_shape(l.template cast<py::list>(), shape);
 | 
			
		||||
    } else if (py::isinstance<py::tuple>(l)) {
 | 
			
		||||
      return get_shape(l.template cast<py::tuple>(), shape);
 | 
			
		||||
    } else if (py::isinstance<array>(l)) {
 | 
			
		||||
      auto arr = py::cast<array>(l);
 | 
			
		||||
      shape.insert(shape.end(), arr.shape().begin(), arr.shape().end());
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
using array_init_type = std::variant<
 | 
			
		||||
    py::bool_,
 | 
			
		||||
    py::int_,
 | 
			
		||||
    py::float_,
 | 
			
		||||
    std::complex<float>,
 | 
			
		||||
    py::list,
 | 
			
		||||
    py::tuple,
 | 
			
		||||
    py::array,
 | 
			
		||||
    py::buffer,
 | 
			
		||||
    py::object>;
 | 
			
		||||
 | 
			
		||||
// Forward declaration
 | 
			
		||||
array create_array(array_init_type v, std::optional<Dtype> t);
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
array array_from_list(
 | 
			
		||||
    T pl,
 | 
			
		||||
    const PyScalarT& inferred_type,
 | 
			
		||||
    std::optional<Dtype> specified_type,
 | 
			
		||||
    const std::vector<int>& shape) {
 | 
			
		||||
  // Make the array
 | 
			
		||||
  switch (inferred_type) {
 | 
			
		||||
    case pybool: {
 | 
			
		||||
      std::vector<bool> vals;
 | 
			
		||||
      fill_vector(pl, vals);
 | 
			
		||||
      return array(vals.begin(), shape, specified_type.value_or(bool_));
 | 
			
		||||
    }
 | 
			
		||||
    case pyint: {
 | 
			
		||||
      std::vector<int> vals;
 | 
			
		||||
      fill_vector(pl, vals);
 | 
			
		||||
      return array(vals.begin(), shape, specified_type.value_or(int32));
 | 
			
		||||
    }
 | 
			
		||||
    case pyfloat: {
 | 
			
		||||
      std::vector<float> vals;
 | 
			
		||||
      fill_vector(pl, vals);
 | 
			
		||||
      return array(vals.begin(), shape, specified_type.value_or(float32));
 | 
			
		||||
    }
 | 
			
		||||
    case pycomplex: {
 | 
			
		||||
      std::vector<std::complex<float>> vals;
 | 
			
		||||
      fill_vector(pl, vals);
 | 
			
		||||
      return array(
 | 
			
		||||
          reinterpret_cast<complex64_t*>(vals.data()),
 | 
			
		||||
          shape,
 | 
			
		||||
          specified_type.value_or(complex64));
 | 
			
		||||
    }
 | 
			
		||||
    default: {
 | 
			
		||||
      std::ostringstream msg;
 | 
			
		||||
      msg << "Should not happen, inferred: " << inferred_type
 | 
			
		||||
          << " on subarray made of only python primitive types.";
 | 
			
		||||
      throw std::runtime_error(msg.str());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@@ -179,39 +262,20 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
 | 
			
		||||
  get_shape(pl, shape);
 | 
			
		||||
 | 
			
		||||
  // Validate the shape and type
 | 
			
		||||
  auto type = validate_shape(pl, shape, 0);
 | 
			
		||||
  bool all_python_primitive_elements = true;
 | 
			
		||||
  auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
 | 
			
		||||
 | 
			
		||||
  size_t size = 1;
 | 
			
		||||
  for (auto s : shape) {
 | 
			
		||||
    size *= s;
 | 
			
		||||
  if (all_python_primitive_elements) {
 | 
			
		||||
    // `pl` does not contain mlx arrays
 | 
			
		||||
    return array_from_list(pl, type, dtype, shape);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Make the array
 | 
			
		||||
  switch (type) {
 | 
			
		||||
    case pybool: {
 | 
			
		||||
      std::vector<bool> vals;
 | 
			
		||||
      fill_vector(pl, vals);
 | 
			
		||||
      return array(vals.begin(), shape, dtype.value_or(bool_));
 | 
			
		||||
    }
 | 
			
		||||
    case pyint: {
 | 
			
		||||
      std::vector<int> vals;
 | 
			
		||||
      fill_vector(pl, vals);
 | 
			
		||||
      return array(vals.begin(), shape, dtype.value_or(int32));
 | 
			
		||||
    }
 | 
			
		||||
    case pyfloat: {
 | 
			
		||||
      std::vector<float> vals;
 | 
			
		||||
      fill_vector(pl, vals);
 | 
			
		||||
      return array(vals.begin(), shape, dtype.value_or(float32));
 | 
			
		||||
    }
 | 
			
		||||
    case pycomplex: {
 | 
			
		||||
      std::vector<std::complex<float>> vals;
 | 
			
		||||
      fill_vector(pl, vals);
 | 
			
		||||
      return array(
 | 
			
		||||
          reinterpret_cast<complex64_t*>(vals.data()),
 | 
			
		||||
          shape,
 | 
			
		||||
          dtype.value_or(complex64));
 | 
			
		||||
    }
 | 
			
		||||
  // `pl` contains mlx arrays
 | 
			
		||||
  std::vector<array> arrays;
 | 
			
		||||
  for (auto l : pl) {
 | 
			
		||||
    arrays.push_back(create_array(py::cast<array_init_type>(l), dtype));
 | 
			
		||||
  }
 | 
			
		||||
  return stack(arrays);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
///////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
@@ -419,6 +483,29 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
 | 
			
		||||
// Module
 | 
			
		||||
///////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
array create_array(array_init_type v, std::optional<Dtype> t) {
 | 
			
		||||
  if (auto pv = std::get_if<py::bool_>(&v); pv) {
 | 
			
		||||
    return array(py::cast<bool>(*pv), t.value_or(bool_));
 | 
			
		||||
  } else if (auto pv = std::get_if<py::int_>(&v); pv) {
 | 
			
		||||
    return array(py::cast<int>(*pv), t.value_or(int32));
 | 
			
		||||
  } else if (auto pv = std::get_if<py::float_>(&v); pv) {
 | 
			
		||||
    return array(py::cast<float>(*pv), t.value_or(float32));
 | 
			
		||||
  } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
 | 
			
		||||
    return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
 | 
			
		||||
  } else if (auto pv = std::get_if<py::list>(&v); pv) {
 | 
			
		||||
    return array_from_list(*pv, t);
 | 
			
		||||
  } else if (auto pv = std::get_if<py::tuple>(&v); pv) {
 | 
			
		||||
    return array_from_list(*pv, t);
 | 
			
		||||
  } else if (auto pv = std::get_if<py::array>(&v); pv) {
 | 
			
		||||
    return np_array_to_mlx(*pv, t);
 | 
			
		||||
  } else if (auto pv = std::get_if<py::buffer>(&v); pv) {
 | 
			
		||||
    return np_array_to_mlx(*pv, t);
 | 
			
		||||
  } else {
 | 
			
		||||
    auto arr = to_array_with_accessor(std::get<py::object>(v));
 | 
			
		||||
    return astype(arr, t.value_or(arr.dtype()));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void init_array(py::module_& m) {
 | 
			
		||||
  // Types
 | 
			
		||||
  py::class_<Dtype>(
 | 
			
		||||
@@ -466,37 +553,8 @@ void init_array(py::module_& m) {
 | 
			
		||||
    options.disable_function_signatures();
 | 
			
		||||
 | 
			
		||||
    array_class.def(
 | 
			
		||||
        py::init([](std::variant<
 | 
			
		||||
                        py::bool_,
 | 
			
		||||
                        py::int_,
 | 
			
		||||
                        py::float_,
 | 
			
		||||
                        std::complex<float>,
 | 
			
		||||
                        py::list,
 | 
			
		||||
                        py::tuple,
 | 
			
		||||
                        py::array,
 | 
			
		||||
                        py::buffer,
 | 
			
		||||
                        py::object> v,
 | 
			
		||||
                    std::optional<Dtype> t) {
 | 
			
		||||
          if (auto pv = std::get_if<py::bool_>(&v); pv) {
 | 
			
		||||
            return array(py::cast<bool>(*pv), t.value_or(bool_));
 | 
			
		||||
          } else if (auto pv = std::get_if<py::int_>(&v); pv) {
 | 
			
		||||
            return array(py::cast<int>(*pv), t.value_or(int32));
 | 
			
		||||
          } else if (auto pv = std::get_if<py::float_>(&v); pv) {
 | 
			
		||||
            return array(py::cast<float>(*pv), t.value_or(float32));
 | 
			
		||||
          } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
 | 
			
		||||
            return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
 | 
			
		||||
          } else if (auto pv = std::get_if<py::list>(&v); pv) {
 | 
			
		||||
            return array_from_list(*pv, t);
 | 
			
		||||
          } else if (auto pv = std::get_if<py::tuple>(&v); pv) {
 | 
			
		||||
            return array_from_list(*pv, t);
 | 
			
		||||
          } else if (auto pv = std::get_if<py::array>(&v); pv) {
 | 
			
		||||
            return np_array_to_mlx(*pv, t);
 | 
			
		||||
          } else if (auto pv = std::get_if<py::buffer>(&v); pv) {
 | 
			
		||||
            return np_array_to_mlx(*pv, t);
 | 
			
		||||
          } else {
 | 
			
		||||
            auto arr = to_array_with_accessor(std::get<py::object>(v));
 | 
			
		||||
            return astype(arr, t.value_or(arr.dtype()));
 | 
			
		||||
          }
 | 
			
		||||
        py::init([](array_init_type v, std::optional<Dtype> t) {
 | 
			
		||||
          return create_array(v, t);
 | 
			
		||||
        }),
 | 
			
		||||
        "val"_a,
 | 
			
		||||
        "dtype"_a = std::nullopt,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user