// Copyright © 2023-2024 Apple Inc. #pragma once #include #include #include "mlx/array.h" namespace py = pybind11; using namespace mlx::core; void tree_visit(py::object tree, std::function visitor); py::object tree_map( const std::vector& trees, std::function&)> transform); py::object tree_map( py::object tree, std::function transform); void tree_visit_update( py::object tree, std::function visitor); /** * Fill a pytree (recursive dict or list of dict or list) in place with the * given arrays. */ void tree_fill(py::object& tree, const std::vector& values); /** * Replace all the arrays from the src values with the dst values in the * tree. */ void tree_replace( py::object& tree, const std::vector& src, const std::vector& dst); /** * Flatten a tree into a vector of arrays. If strict is true, then the * function will throw if the tree contains a leaf which is not an array. */ std::vector tree_flatten(py::object tree, bool strict = true); /** * Unflatten a tree from a vector of arrays. */ py::object tree_unflatten( py::object tree, const std::vector& values, int index = 0); std::pair, py::object> tree_flatten_with_structure( py::object tree, bool strict = true); py::object tree_unflatten_from_structure( py::object structure, const std::vector& values, int index = 0);