// Copyright © 2024 Apple Inc. #include #include "python/src/convert.h" #include "python/src/utils.h" #include "mlx/utils.h" enum PyScalarT { pybool = 0, pyint = 1, pyfloat = 2, pycomplex = 3, }; namespace nanobind { template <> struct ndarray_traits { static constexpr bool is_complex = false; static constexpr bool is_float = true; static constexpr bool is_bool = false; static constexpr bool is_int = false; static constexpr bool is_signed = true; }; static constexpr dlpack::dtype bfloat16{4, 16, 1}; }; // namespace nanobind int check_shape_dim(int64_t dim) { if (dim > std::numeric_limits::max()) { throw std::invalid_argument( "Shape dimension falls outside supported `int` range."); } return static_cast(dim); } template mx::array nd_array_to_mlx_contiguous( nb::ndarray nd_array, const mx::Shape& shape, mx::Dtype dtype) { // Make a copy of the numpy buffer // Get buffer ptr pass to array constructor auto data_ptr = nd_array.data(); return mx::array(static_cast(data_ptr), shape, dtype); } mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype) { // Compute the shape and size mx::Shape shape; for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } auto type = nd_array.dtype(); // Copy data and make array if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::bool_)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::uint8)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::uint16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::uint32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::uint64)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::int8)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::int16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::int32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::int64)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::float16)); } else if (type == nb::bfloat16) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::bfloat16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::float32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::float32)); } else if (type == nb::dtype>()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::complex64)); } else if (type == nb::dtype>()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(mx::complex64)); } else { throw std::invalid_argument("Cannot convert numpy array to mlx array."); } } template nb::ndarray mlx_to_nd_array_impl( mx::array a, std::optional t = {}) { { nb::gil_scoped_release nogil; a.eval(); } std::vector shape(a.shape().begin(), a.shape().end()); return nb::ndarray( a.data(), a.ndim(), shape.data(), /* owner= */ nb::none(), a.strides().data(), t.value_or(nb::dtype())); } template nb::ndarray mlx_to_nd_array(const mx::array& a) { switch (a.dtype()) { case mx::bool_: return mlx_to_nd_array_impl(a); case mx::uint8: return mlx_to_nd_array_impl(a); case mx::uint16: return mlx_to_nd_array_impl(a); case mx::uint32: return mlx_to_nd_array_impl(a); case mx::uint64: return mlx_to_nd_array_impl(a); case mx::int8: return mlx_to_nd_array_impl(a); case mx::int16: return mlx_to_nd_array_impl(a); case mx::int32: return mlx_to_nd_array_impl(a); case mx::int64: return mlx_to_nd_array_impl(a); case mx::float16: return mlx_to_nd_array_impl(a); case mx::bfloat16: throw nb::type_error("bfloat16 arrays cannot be converted to NumPy."); case mx::float32: return mlx_to_nd_array_impl(a); case mx::float64: return mlx_to_nd_array_impl(a); case mx::complex64: return mlx_to_nd_array_impl, NDParams...>(a); default: throw nb::type_error("type cannot be converted to NumPy."); } } nb::ndarray mlx_to_np_array(const mx::array& a) { return mlx_to_nd_array(a); } nb::ndarray<> mlx_to_dlpack(const mx::array& a) { return mlx_to_nd_array<>(a); } nb::object to_scalar(mx::array& a) { if (a.size() != 1) { throw std::invalid_argument( "[convert] Only length-1 arrays can be converted to Python scalars."); } { nb::gil_scoped_release nogil; a.eval(); } switch (a.dtype()) { case mx::bool_: return nb::cast(a.item()); case mx::uint8: return nb::cast(a.item()); case mx::uint16: return nb::cast(a.item()); case mx::uint32: return nb::cast(a.item()); case mx::uint64: return nb::cast(a.item()); case mx::int8: return nb::cast(a.item()); case mx::int16: return nb::cast(a.item()); case mx::int32: return nb::cast(a.item()); case mx::int64: return nb::cast(a.item()); case mx::float16: return nb::cast(static_cast(a.item())); case mx::float32: return nb::cast(a.item()); case mx::bfloat16: return nb::cast(static_cast(a.item())); case mx::complex64: return nb::cast(a.item>()); default: throw nb::type_error("type cannot be converted to Python scalar."); } } template nb::list to_list(mx::array& a, size_t index, int dim) { nb::list pl; auto stride = a.strides()[dim]; for (int i = 0; i < a.shape(dim); ++i) { if (dim == a.ndim() - 1) { pl.append(static_cast(a.data()[index])); } else { pl.append(to_list(a, index, dim + 1)); } index += stride; } return pl; } nb::object tolist(mx::array& a) { if (a.ndim() == 0) { return to_scalar(a); } { nb::gil_scoped_release nogil; a.eval(); } switch (a.dtype()) { case mx::bool_: return to_list(a, 0, 0); case mx::uint8: return to_list(a, 0, 0); case mx::uint16: return to_list(a, 0, 0); case mx::uint32: return to_list(a, 0, 0); case mx::uint64: return to_list(a, 0, 0); case mx::int8: return to_list(a, 0, 0); case mx::int16: return to_list(a, 0, 0); case mx::int32: return to_list(a, 0, 0); case mx::int64: return to_list(a, 0, 0); case mx::float16: return to_list(a, 0, 0); case mx::float32: return to_list(a, 0, 0); case mx::bfloat16: return to_list(a, 0, 0); case mx::complex64: return to_list>(a, 0, 0); default: throw nb::type_error("data type cannot be converted to Python list."); } } template void fill_vector(T list, std::vector& vals) { for (auto l : list) { if (nb::isinstance(l)) { fill_vector(nb::cast(l), vals); } else if (nb::isinstance(*list.begin())) { fill_vector(nb::cast(l), vals); } else { vals.push_back(nb::cast(l)); } } } template PyScalarT validate_shape( T list, const mx::Shape& shape, int idx, bool& all_python_primitive_elements) { if (idx >= shape.size()) { throw std::invalid_argument("Initialization encountered extra dimension."); } auto s = shape[idx]; if (nb::len(list) != s) { throw std::invalid_argument( "Initialization encountered non-uniform length."); } if (s == 0) { return pyfloat; } PyScalarT type = pybool; for (auto l : list) { PyScalarT t; if (nb::isinstance(l)) { t = validate_shape( nb::cast(l), shape, idx + 1, all_python_primitive_elements); } else if (nb::isinstance(*list.begin())) { t = validate_shape( nb::cast(l), shape, idx + 1, all_python_primitive_elements); } else if (nb::isinstance(l)) { all_python_primitive_elements = false; auto arr = nb::cast(l); if (arr.ndim() + idx + 1 == shape.size() && std::equal( arr.shape().cbegin(), arr.shape().cend(), shape.cbegin() + idx + 1)) { t = pybool; } else { throw std::invalid_argument( "Initialization encountered non-uniform length."); } } else { if (nb::isinstance(l)) { t = pybool; } else if (nb::isinstance(l)) { t = pyint; } else if (nb::isinstance(l)) { t = pyfloat; } else if (PyComplex_Check(l.ptr())) { t = pycomplex; } else { std::ostringstream msg; msg << "Invalid type " << nb::type_name(l.type()).c_str() << " received in array initialization."; throw std::invalid_argument(msg.str()); } if (idx + 1 != shape.size()) { throw std::invalid_argument( "Initialization encountered non-uniform length."); } } type = std::max(type, t); } return type; } template void get_shape(T list, mx::Shape& shape) { shape.push_back(check_shape_dim(nb::len(list))); if (shape.back() > 0) { auto l = list.begin(); if (nb::isinstance(*l)) { return get_shape(nb::cast(*l), shape); } else if (nb::isinstance(*l)) { return get_shape(nb::cast(*l), shape); } else if (nb::isinstance(*l)) { auto arr = nb::cast(*l); for (int i = 0; i < arr.ndim(); i++) { shape.push_back(arr.shape(i)); } return; } } } template mx::array array_from_list_impl( T pl, const PyScalarT& inferred_type, std::optional specified_type, const mx::Shape& shape) { // Make the array switch (inferred_type) { case pybool: { std::vector vals; fill_vector(pl, vals); return mx::array(vals.begin(), shape, specified_type.value_or(mx::bool_)); } case pyint: { auto dtype = specified_type.value_or(mx::int32); if (dtype == mx::int64) { std::vector vals; fill_vector(pl, vals); return mx::array(vals.begin(), shape, dtype); } else if (dtype == mx::uint64) { std::vector vals; fill_vector(pl, vals); return mx::array(vals.begin(), shape, dtype); } else if (dtype == mx::uint32) { std::vector vals; fill_vector(pl, vals); return mx::array(vals.begin(), shape, dtype); } else if (mx::issubdtype(dtype, mx::inexact)) { std::vector vals; fill_vector(pl, vals); return mx::array(vals.begin(), shape, dtype); } else { std::vector vals; fill_vector(pl, vals); return mx::array(vals.begin(), shape, dtype); } } case pyfloat: { std::vector vals; fill_vector(pl, vals); return mx::array( vals.begin(), shape, specified_type.value_or(mx::float32)); } case pycomplex: { std::vector> vals; fill_vector(pl, vals); return mx::array( reinterpret_cast(vals.data()), shape, specified_type.value_or(mx::complex64)); } default: { std::ostringstream msg; msg << "Should not happen, inferred: " << inferred_type << " on subarray made of only python primitive types."; throw std::runtime_error(msg.str()); } } } template mx::array array_from_list_impl(T pl, std::optional dtype) { // Compute the shape mx::Shape shape; get_shape(pl, shape); // Validate the shape and type bool all_python_primitive_elements = true; auto type = validate_shape(pl, shape, 0, all_python_primitive_elements); if (all_python_primitive_elements) { // `pl` does not contain mlx arrays return array_from_list_impl(pl, type, dtype, shape); } // `pl` contains mlx arrays std::vector arrays; for (auto l : pl) { arrays.push_back(create_array(nb::cast(l), dtype)); } return mx::stack(arrays); } mx::array array_from_list(nb::list pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } mx::array array_from_list(nb::tuple pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } mx::array create_array(ArrayInitType v, std::optional t) { if (auto pv = std::get_if(&v); pv) { return mx::array(nb::cast(*pv), t.value_or(mx::bool_)); } else if (auto pv = std::get_if(&v); pv) { auto val = nb::cast(*pv); auto default_type = (val > std::numeric_limits::max() || val < std::numeric_limits::min()) ? mx::int64 : mx::int32; return mx::array(val, t.value_or(default_type)); } else if (auto pv = std::get_if(&v); pv) { return mx::array(nb::cast(*pv), t.value_or(mx::float32)); } else if (auto pv = std::get_if>(&v); pv) { return mx::array( static_cast(*pv), t.value_or(mx::complex64)); } else if (auto pv = std::get_if(&v); pv) { return array_from_list(*pv, t); } else if (auto pv = std::get_if(&v); pv) { return array_from_list(*pv, t); } else if (auto pv = std::get_if< nb::ndarray>(&v); pv) { return nd_array_to_mlx(*pv, t); } else if (auto pv = std::get_if(&v); pv) { return mx::astype(*pv, t.value_or((*pv).dtype())); } else { auto arr = to_array_with_accessor(std::get(v).obj); return mx::astype(arr, t.value_or(arr.dtype())); } }