// Copyright © 2024 Apple Inc. #include #include "python/src/convert.h" namespace nanobind { template <> struct ndarray_traits { static constexpr bool is_complex = false; static constexpr bool is_float = true; static constexpr bool is_bool = false; static constexpr bool is_int = false; static constexpr bool is_signed = true; }; template <> struct ndarray_traits { static constexpr bool is_complex = false; static constexpr bool is_float = true; static constexpr bool is_bool = false; static constexpr bool is_int = false; static constexpr bool is_signed = true; }; static constexpr dlpack::dtype bfloat16{4, 16, 1}; }; // namespace nanobind template array nd_array_to_mlx_contiguous( nb::ndarray nd_array, const std::vector& shape, Dtype dtype) { // Make a copy of the numpy buffer // Get buffer ptr pass to array constructor auto data_ptr = nd_array.data(); return array(static_cast(data_ptr), shape, dtype); } array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype) { // Compute the shape and size std::vector shape; for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(nd_array.shape(i)); } auto type = nd_array.dtype(); // Copy data and make array if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(bool_)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(uint8)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(uint16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(uint32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(uint64)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(int8)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(int16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(int32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(int64)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(float16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(bfloat16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(float32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(float32)); } else if (type == nb::dtype>()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(complex64)); } else if (type == nb::dtype>()) { return nd_array_to_mlx_contiguous( nd_array, shape, dtype.value_or(complex64)); } else { throw std::invalid_argument("Cannot convert numpy array to mlx array."); } } template nb::ndarray mlx_to_nd_array( array a, std::optional t = {}) { // Eval if not already evaled if (!a.is_evaled()) { nb::gil_scoped_release nogil; a.eval(); } std::vector shape(a.shape().begin(), a.shape().end()); std::vector strides(a.strides().begin(), a.strides().end()); return nb::ndarray( a.data(), a.ndim(), shape.data(), nb::handle(), strides.data(), t.value_or(nb::dtype())); } template nb::ndarray mlx_to_nd_array(const array& a) { switch (a.dtype()) { case bool_: return mlx_to_nd_array(a); case uint8: return mlx_to_nd_array(a); case uint16: return mlx_to_nd_array(a); case uint32: return mlx_to_nd_array(a); case uint64: return mlx_to_nd_array(a); case int8: return mlx_to_nd_array(a); case int16: return mlx_to_nd_array(a); case int32: return mlx_to_nd_array(a); case int64: return mlx_to_nd_array(a); case float16: return mlx_to_nd_array(a); case bfloat16: return mlx_to_nd_array(a, nb::bfloat16); case float32: return mlx_to_nd_array(a); case complex64: return mlx_to_nd_array>(a); } } nb::ndarray mlx_to_np_array(const array& a) { return mlx_to_nd_array(a); }