// Copyright © 2023-2024 Apple Inc. #include "python/src/trees.h" template void validate_subtrees(const std::vector& subtrees) { int len = nb::cast(subtrees[0]).size(); for (auto& subtree : subtrees) { if ((nb::isinstance(subtree) && nb::cast(subtree).size() != len) || nb::isinstance(subtree) || nb::isinstance(subtree)) { throw std::invalid_argument( "[tree_map] Additional input tree is not a valid prefix of the first tree."); } } } nb::object tree_map( const std::vector& trees, std::function&)> transform) { std::function&)> recurse; recurse = [&](const std::vector& subtrees) { if (nb::isinstance(subtrees[0])) { nb::list l; std::vector items(subtrees.size()); validate_subtrees(subtrees); for (int i = 0; i < nb::cast(subtrees[0]).size(); ++i) { for (int j = 0; j < subtrees.size(); ++j) { if (nb::isinstance(subtrees[j])) { items[j] = nb::cast(subtrees[j])[i]; } else { items[j] = subtrees[j]; } } l.append(recurse(items)); } return nb::cast(l); } else if (nb::isinstance(subtrees[0])) { // Check the rest of the subtrees std::vector items(subtrees.size()); int len = nb::cast(subtrees[0]).size(); nb::list l; validate_subtrees(subtrees); for (int i = 0; i < len; ++i) { for (int j = 0; j < subtrees.size(); ++j) { if (nb::isinstance(subtrees[j])) { items[j] = nb::cast(subtrees[j])[i]; } else { items[j] = subtrees[j]; } } l.append(recurse(items)); } return nb::cast(nb::tuple(l)); } else if (nb::isinstance(subtrees[0])) { std::vector items(subtrees.size()); validate_subtrees(subtrees); nb::dict d; for (auto item : nb::cast(subtrees[0])) { for (int j = 0; j < subtrees.size(); ++j) { if (nb::isinstance(subtrees[j])) { auto subdict = nb::cast(subtrees[j]); if (!subdict.contains(item.first)) { throw std::invalid_argument( "[tree_map] Tree is not a valid prefix tree of the first tree."); } items[j] = subdict[item.first]; } else { items[j] = subtrees[j]; } } d[item.first] = recurse(items); } return nb::cast(d); } else { return transform(subtrees); } }; return recurse(trees); } nb::object tree_map( nb::object tree, std::function transform) { return tree_map({tree}, [&](std::vector inputs) { return transform(inputs[0]); }); } void tree_visit( const std::vector& trees, std::function&)> visitor) { std::function&)> recurse; recurse = [&](const std::vector& subtrees) { if (nb::isinstance(subtrees[0])) { std::vector items(subtrees.size()); validate_subtrees(subtrees); for (int i = 0; i < nb::cast(subtrees[0]).size(); ++i) { for (int j = 0; j < subtrees.size(); ++j) { if (nb::isinstance(subtrees[j])) { items[j] = nb::cast(subtrees[j])[i]; } else { items[j] = subtrees[j]; } } recurse(items); } } else if (nb::isinstance(subtrees[0])) { // Check the rest of the subtrees std::vector items(subtrees.size()); int len = nb::cast(subtrees[0]).size(); validate_subtrees(subtrees); for (int i = 0; i < len; ++i) { for (int j = 0; j < subtrees.size(); ++j) { if (nb::isinstance(subtrees[j])) { items[j] = nb::cast(subtrees[j])[i]; } else { items[j] = subtrees[j]; } } recurse(items); } } else if (nb::isinstance(subtrees[0])) { std::vector items(subtrees.size()); validate_subtrees(subtrees); for (auto item : nb::cast(subtrees[0])) { for (int j = 0; j < subtrees.size(); ++j) { if (nb::isinstance(subtrees[j])) { auto subdict = nb::cast(subtrees[j]); if (!subdict.contains(item.first)) { throw std::invalid_argument( "[tree_visit] Tree is not a valid prefix tree of the first tree."); } items[j] = subdict[item.first]; } else { items[j] = subtrees[j]; } } recurse(items); } } else { visitor(subtrees); } }; return recurse(trees); } void tree_visit(nb::handle tree, std::function visitor) { std::function recurse; recurse = [&](nb::handle subtree) { if (nb::isinstance(subtree) || nb::isinstance(subtree)) { for (auto item : subtree) { recurse(item); } } else if (nb::isinstance(subtree)) { for (auto item : nb::cast(subtree)) { recurse(item.second); } } else { visitor(subtree); } }; recurse(tree); } void tree_visit_update( nb::object tree, std::function visitor) { std::function recurse; recurse = [&](nb::handle subtree) { if (nb::isinstance(subtree)) { auto l = nb::cast(subtree); for (int i = 0; i < l.size(); ++i) { l[i] = recurse(l[i]); } return nb::cast(l); } else if (nb::isinstance(subtree)) { nb::list l(subtree); for (int i = 0; i < l.size(); ++i) { l[i] = recurse(l[i]); } return nb::cast(nb::tuple(l)); } else if (nb::isinstance(subtree)) { auto d = nb::cast(subtree); for (auto item : d) { d[item.first] = recurse(item.second); } return nb::cast(d); } else if (nb::isinstance(subtree)) { return visitor(subtree); } else { return nb::cast(subtree); } }; recurse(tree); } // Fill a pytree (recursive dict or list of dict or list) // in place with the given arrays // Non dict or list nodes are ignored void tree_fill(nb::object& tree, const std::vector& values) { size_t index = 0; tree_visit_update( tree, [&](nb::handle node) { return nb::cast(values[index++]); }); } // Replace all the arrays from the src values with the dst values in the tree void tree_replace( nb::object& tree, const std::vector& src, const std::vector& dst) { std::unordered_map src_to_dst; for (int i = 0; i < src.size(); ++i) { src_to_dst.insert({src[i].id(), dst[i]}); } tree_visit_update(tree, [&](nb::handle node) { auto arr = nb::cast(node); if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { return nb::cast(it->second); } return nb::cast(arr); }); } std::vector tree_flatten(nb::handle tree, bool strict /* = true */) { std::vector flat_tree; tree_visit(tree, [&](nb::handle obj) { if (nb::isinstance(obj)) { flat_tree.push_back(nb::cast(obj)); } else if (strict) { throw std::invalid_argument( "[tree_flatten] The argument should contain only arrays"); } }); return flat_tree; } nb::object tree_unflatten( nb::object tree, const std::vector& values, int index /* = 0 */) { return tree_map(tree, [&](nb::handle obj) { if (nb::isinstance(obj)) { return nb::cast(values[index++]); } else { return nb::cast(obj); } }); } nb::object structure_sentinel() { static nb::object sentinel; if (sentinel.ptr() == nullptr) { sentinel = nb::capsule(&sentinel); // probably not needed but this should make certain that we won't ever // delete the sentinel sentinel.inc_ref(); } return sentinel; } std::pair, nb::object> tree_flatten_with_structure( nb::object tree, bool strict /* = true */) { auto sentinel = structure_sentinel(); std::vector flat_tree; auto structure = tree_map( tree, [&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) { if (nb::isinstance(obj)) { flat_tree.push_back(nb::cast(obj)); return sentinel; } else if (!strict) { return nb::cast(obj); } else { throw std::invalid_argument( "[tree_flatten] The argument should contain only arrays"); } }); return {flat_tree, structure}; } nb::object tree_unflatten_from_structure( nb::object structure, const std::vector& values, int index /* = 0 */) { auto sentinel = structure_sentinel(); return tree_map(structure, [&](nb::handle obj) { if (obj.is(sentinel)) { return nb::cast(values[index++]); } else { return nb::cast(obj); } }); }