// Copyright © 2024 Apple Inc. #pragma once #include #include #include #include "mlx/array.h" #include "mlx/ops.h" namespace nb = nanobind; using namespace mlx::core; using ArrayInitType = std::variant< nb::bool_, nb::int_, nb::float_, // Must be above ndarray array, // Must be above complex nb::ndarray, std::complex, nb::list, nb::tuple, nb::object>; array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype); nb::ndarray mlx_to_np_array(const array& a); nb::ndarray<> mlx_to_dlpack(const array& a); nb::object to_scalar(array& a); nb::object tolist(array& a); array create_array(ArrayInitType v, std::optional t); array array_from_list(nb::list pl, std::optional dtype); array array_from_list(nb::tuple pl, std::optional dtype);