// Copyright © 2023-2024 Apple Inc. #pragma once #include #include "mlx/array.h" namespace mx = mlx::core; namespace nb = nanobind; void tree_visit( const std::vector& trees, std::function&)> visitor); void tree_visit(nb::handle tree, std::function visitor); nb::object tree_map( const std::vector& trees, std::function&)> transform); nb::object tree_map( nb::object tree, std::function transform); void tree_visit_update( nb::object tree, std::function visitor); /** * Fill a pytree (recursive dict or list of dict or list) in place with the * given arrays. */ void tree_fill(nb::object& tree, const std::vector& values); /** * Replace all the arrays from the src values with the dst values in the * tree. */ void tree_replace( nb::object& tree, const std::vector& src, const std::vector& dst); /** * Flatten a tree into a vector of arrays. If strict is true, then the * function will throw if the tree contains a leaf which is not an array. */ std::vector tree_flatten(nb::handle tree, bool strict = true); /** * Unflatten a tree from a vector of arrays. */ nb::object tree_unflatten( nb::object tree, const std::vector& values, int index = 0); std::pair, nb::object> tree_flatten_with_structure( nb::object tree, bool strict = true); nb::object tree_unflatten_from_structure( nb::object structure, const std::vector& values, int index = 0);