// Copyright © 2024 Apple Inc. #include #include "python/src/convert.h" #include "python/src/utils.h" #include "mlx/utils.h" enum PyScalarT { pybool = 0, pyint = 1, pyfloat = 2, pycomplex = 3, }; namespace nanobind { template <> struct ndarray_traits { static constexpr bool is_complex = false; static constexpr bool is_float = true; static constexpr bool is_bool = false; static constexpr bool is_int = false; static constexpr bool is_signed = true; }; template <> struct ndarray_traits { static constexpr bool is_complex = false; static constexpr bool is_float = true; static constexpr bool is_bool = false; static constexpr bool is_int = false; static constexpr bool is_signed = true; }; static constexpr dlpack::dtype bfloat16{4, 16, 1}; }; // namespace nanobind template array nd_array_to_mlx_contiguous( nb::ndarray nd_array, const std::vector& shape, Dtype dtype) { // Make a copy of the numpy buffer // Get buffer ptr pass to array constructor auto data_ptr = nd_array.data(); return array(static_cast(data_ptr), shape, dtype); } array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype) { // Compute the shape and size std::vector shape; for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } auto type = nd_array.dtype(); // Copy data and make array if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(bool_)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(uint8)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(uint16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(uint32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(uint64)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(int8)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(int16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(int32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(int64)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(float16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(bfloat16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(float32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(float32)); } else if (type == nb::dtype>()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(complex64)); } else if (type == nb::dtype>()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(complex64)); } else { throw std::invalid_argument("Cannot convert numpy array to mlx array."); } } template nb::ndarray mlx_to_nd_array_impl( array a, std::optional t = {}) { { nb::gil_scoped_release nogil; a.eval(); } std::vector shape(a.shape().begin(), a.shape().end()); std::vector strides(a.strides().begin(), a.strides().end()); return nb::ndarray( a.data(), a.ndim(), shape.data(), nb::none(), strides.data(), t.value_or(nb::dtype())); } template nb::ndarray mlx_to_nd_array(const array& a) { switch (a.dtype()) { case bool_: return mlx_to_nd_array_impl(a); case uint8: return mlx_to_nd_array_impl(a); case uint16: return mlx_to_nd_array_impl(a); case uint32: return mlx_to_nd_array_impl(a); case uint64: return mlx_to_nd_array_impl(a); case int8: return mlx_to_nd_array_impl(a); case int16: return mlx_to_nd_array_impl(a); case int32: return mlx_to_nd_array_impl(a); case int64: return mlx_to_nd_array_impl(a); case float16: return mlx_to_nd_array_impl(a); case bfloat16: return mlx_to_nd_array_impl(a, nb::bfloat16); case float32: return mlx_to_nd_array_impl(a); case complex64: return mlx_to_nd_array_impl, NDParams...>(a); } } nb::ndarray mlx_to_np_array(const array& a) { return mlx_to_nd_array(a); } nb::ndarray<> mlx_to_dlpack(const array& a) { return mlx_to_nd_array<>(a); } nb::object to_scalar(array& a) { { nb::gil_scoped_release nogil; a.eval(); } switch (a.dtype()) { case bool_: return nb::cast(a.item()); case uint8: return nb::cast(a.item()); case uint16: return nb::cast(a.item()); case uint32: return nb::cast(a.item()); case uint64: return nb::cast(a.item()); case int8: return nb::cast(a.item()); case int16: return nb::cast(a.item()); case int32: return nb::cast(a.item()); case int64: return nb::cast(a.item()); case float16: return nb::cast(static_cast(a.item())); case float32: return nb::cast(a.item()); case bfloat16: return nb::cast(static_cast(a.item())); case complex64: return nb::cast(a.item>()); } } template nb::list to_list(array& a, size_t index, int dim) { nb::list pl; auto stride = a.strides()[dim]; for (int i = 0; i < a.shape(dim); ++i) { if (dim == a.ndim() - 1) { pl.append(static_cast(a.data()[index])); } else { pl.append(to_list(a, index, dim + 1)); } index += stride; } return pl; } nb::object tolist(array& a) { if (a.ndim() == 0) { return to_scalar(a); } { nb::gil_scoped_release nogil; a.eval(); } switch (a.dtype()) { case bool_: return to_list(a, 0, 0); case uint8: return to_list(a, 0, 0); case uint16: return to_list(a, 0, 0); case uint32: return to_list(a, 0, 0); case uint64: return to_list(a, 0, 0); case int8: return to_list(a, 0, 0); case int16: return to_list(a, 0, 0); case int32: return to_list(a, 0, 0); case int64: return to_list(a, 0, 0); case float16: return to_list(a, 0, 0); case float32: return to_list(a, 0, 0); case bfloat16: return to_list(a, 0, 0); case complex64: return to_list>(a, 0, 0); } } template void fill_vector(T list, std::vector& vals) { for (auto l : list) { if (nb::isinstance(l)) { fill_vector(nb::cast(l), vals); } else if (nb::isinstance(*list.begin())) { fill_vector(nb::cast(l), vals); } else { vals.push_back(nb::cast(l)); } } } template PyScalarT validate_shape( T list, const std::vector& shape, int idx, bool& all_python_primitive_elements) { if (idx >= shape.size()) { throw std::invalid_argument("Initialization encountered extra dimension."); } auto s = shape[idx]; if (nb::len(list) != s) { throw std::invalid_argument( "Initialization encountered non-uniform length."); } if (s == 0) { return pyfloat; } PyScalarT type = pybool; for (auto l : list) { PyScalarT t; if (nb::isinstance(l)) { t = validate_shape( nb::cast(l), shape, idx + 1, all_python_primitive_elements); } else if (nb::isinstance(*list.begin())) { t = validate_shape( nb::cast(l), shape, idx + 1, all_python_primitive_elements); } else if (nb::isinstance(l)) { all_python_primitive_elements = false; auto arr = nb::cast(l); if (arr.ndim() + idx + 1 == shape.size() && std::equal( arr.shape().cbegin(), arr.shape().cend(), shape.cbegin() + idx + 1)) { t = pybool; } else { throw std::invalid_argument( "Initialization encountered non-uniform length."); } } else { if (nb::isinstance(l)) { t = pybool; } else if (nb::isinstance(l)) { t = pyint; } else if (nb::isinstance(l)) { t = pyfloat; } else if (PyComplex_Check(l.ptr())) { t = pycomplex; } else { std::ostringstream msg; msg << "Invalid type " << nb::type_name(l.type()).c_str() << " received in array initialization."; throw std::invalid_argument(msg.str()); } if (idx + 1 != shape.size()) { throw std::invalid_argument( "Initialization encountered non-uniform length."); } } type = std::max(type, t); } return type; } template void get_shape(T list, std::vector& shape) { shape.push_back(check_shape_dim(nb::len(list))); if (shape.back() > 0) { auto l = list.begin(); if (nb::isinstance(*l)) { return get_shape(nb::cast(*l), shape); } else if (nb::isinstance(*l)) { return get_shape(nb::cast(*l), shape); } else if (nb::isinstance(*l)) { auto arr = nb::cast(*l); for (int i = 0; i < arr.ndim(); i++) { shape.push_back(check_shape_dim(arr.shape(i))); } return; } } } template array array_from_list_impl( T pl, const PyScalarT& inferred_type, std::optional specified_type, const std::vector& shape) { // Make the array switch (inferred_type) { case pybool: { std::vector vals; fill_vector(pl, vals); return array(vals.begin(), shape, specified_type.value_or(bool_)); } case pyint: { auto dtype = specified_type.value_or(int32); if (dtype == int64) { std::vector vals; fill_vector(pl, vals); return array(vals.begin(), shape, dtype); } else if (dtype == uint64) { std::vector vals; fill_vector(pl, vals); return array(vals.begin(), shape, dtype); } else if (dtype == uint32) { std::vector vals; fill_vector(pl, vals); return array(vals.begin(), shape, dtype); } else if (issubdtype(dtype, inexact)) { std::vector vals; fill_vector(pl, vals); return array(vals.begin(), shape, dtype); } else { std::vector vals; fill_vector(pl, vals); return array(vals.begin(), shape, dtype); } } case pyfloat: { std::vector vals; fill_vector(pl, vals); return array(vals.begin(), shape, specified_type.value_or(float32)); } case pycomplex: { std::vector> vals; fill_vector(pl, vals); return array( reinterpret_cast(vals.data()), shape, specified_type.value_or(complex64)); } default: { std::ostringstream msg; msg << "Should not happen, inferred: " << inferred_type << " on subarray made of only python primitive types."; throw std::runtime_error(msg.str()); } } } template array array_from_list_impl(T pl, std::optional dtype) { // Compute the shape std::vector shape; get_shape(pl, shape); // Validate the shape and type bool all_python_primitive_elements = true; auto type = validate_shape(pl, shape, 0, all_python_primitive_elements); if (all_python_primitive_elements) { // `pl` does not contain mlx arrays return array_from_list_impl(pl, type, dtype, shape); } // `pl` contains mlx arrays std::vector arrays; for (auto l : pl) { arrays.push_back(create_array(nb::cast(l), dtype)); } return stack(arrays); } array array_from_list(nb::list pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } array array_from_list(nb::tuple pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } array create_array(ArrayInitType v, std::optional t) { if (auto pv = std::get_if(&v); pv) { return array(nb::cast(*pv), t.value_or(bool_)); } else if (auto pv = std::get_if(&v); pv) { return array(nb::cast(*pv), t.value_or(int32)); } else if (auto pv = std::get_if(&v); pv) { return array(nb::cast(*pv), t.value_or(float32)); } else if (auto pv = std::get_if>(&v); pv) { return array(static_cast(*pv), t.value_or(complex64)); } else if (auto pv = std::get_if(&v); pv) { return array_from_list(*pv, t); } else if (auto pv = std::get_if(&v); pv) { return array_from_list(*pv, t); } else if (auto pv = std::get_if< nb::ndarray>(&v); pv) { return nd_array_to_mlx(*pv, t); } else if (auto pv = std::get_if(&v); pv) { return astype(*pv, t.value_or((*pv).dtype())); } else { auto arr = to_array_with_accessor(std::get(v)); return astype(arr, t.value_or(arr.dtype())); } }