// Copyright © 2024 Apple Inc. #pragma once #include #include #include #include "mlx/array.h" #include "mlx/ops.h" namespace mx = mlx::core; namespace nb = nanobind; struct ArrayLike { ArrayLike(nb::object obj) : obj(obj) {}; nb::object obj; }; using ArrayInitType = std::variant< nb::bool_, nb::int_, nb::float_, // Must be above ndarray mx::array, // Must be above complex nb::ndarray, std::complex, nb::list, nb::tuple, ArrayLike>; mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype); nb::ndarray mlx_to_np_array(const mx::array& a); nb::ndarray<> mlx_to_dlpack(const mx::array& a); nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); mx::array create_array(ArrayInitType v, std::optional t); mx::array array_from_list(nb::list pl, std::optional dtype); mx::array array_from_list(nb::tuple pl, std::optional dtype);