mlx/python/src/trees.cpp
Awni Hannun 33421c1dd3
Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth

* add grad of module test
2025-01-14 14:33:18 -08:00

304 lines
9.6 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#include "python/src/trees.h"
// Check that every entry of `subtrees` is compatible with the first one at
// the current level of the traversal.
//
// T is the container type of `subtrees[0]`; U and V are the two other
// container types. An entry is rejected when it is a T whose size differs
// from the size of `subtrees[0]`, or when it is a U or a V. Any other object
// is accepted.
//
// Throws std::invalid_argument on the first rejected entry.
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<nb::object>& subtrees) {
  // Keep the length in whatever type size() returns: storing it in an int
  // would narrow it and make the comparison below a signed/unsigned mix.
  const auto len = nb::cast<T>(subtrees[0]).size();
  for (const auto& subtree : subtrees) {
    const bool size_mismatch =
        nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len;
    if (size_mismatch || nb::isinstance<U>(subtree) ||
        nb::isinstance<V>(subtree)) {
      throw std::invalid_argument(
          "[tree_map] Additional input tree is not a valid prefix of the first tree.");
    }
  }
}
// Map `transform` over the leaves of one or more trees and rebuild the
// container structure of the first tree around the results.
//
// The first tree drives the traversal: lists, tuples and dicts found in it are
// descended into, and any other object is a leaf. At every level each
// additional tree must either be the same kind of container (same size, and
// for dicts holding every key of the first one) or a non-container object, in
// which case that object is handed unchanged to every child. `transform`
// receives one entry per input tree.
//
// Throws std::invalid_argument when an additional tree does not match the
// first tree.
nb::object tree_map(
    const std::vector<nb::object>& trees,
    std::function<nb::object(const std::vector<nb::object>&)> transform) {
  using Nodes = std::vector<nb::object>;
  std::function<nb::object(const Nodes&)> recurse;

  recurse = [&](const Nodes& subtrees) -> nb::object {
    const size_t num_trees = subtrees.size();

    // List: map element by element into a new list.
    if (nb::isinstance<nb::list>(subtrees[0])) {
      nb::list mapped;
      Nodes children(num_trees);
      validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
      const auto head = nb::cast<nb::list>(subtrees[0]);
      // The list length is queried again on every iteration.
      for (size_t i = 0; i < head.size(); ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          if (nb::isinstance<nb::list>(subtrees[j])) {
            children[j] = nb::cast<nb::list>(subtrees[j])[i];
          } else {
            children[j] = subtrees[j];
          }
        }
        mapped.append(recurse(children));
      }
      return nb::cast<nb::object>(mapped);
    }

    // Tuple: collect the results in a list, then convert it to a tuple.
    if (nb::isinstance<nb::tuple>(subtrees[0])) {
      Nodes children(num_trees);
      const size_t len = nb::cast<nb::tuple>(subtrees[0]).size();
      nb::list mapped;
      validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
      for (size_t i = 0; i < len; ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          if (nb::isinstance<nb::tuple>(subtrees[j])) {
            children[j] = nb::cast<nb::tuple>(subtrees[j])[i];
          } else {
            children[j] = subtrees[j];
          }
        }
        mapped.append(recurse(children));
      }
      return nb::cast<nb::object>(nb::tuple(mapped));
    }

    // Dict: map value by value, following the keys of the first tree.
    if (nb::isinstance<nb::dict>(subtrees[0])) {
      Nodes children(num_trees);
      validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
      nb::dict mapped;
      for (auto entry : nb::cast<nb::dict>(subtrees[0])) {
        for (size_t j = 0; j < num_trees; ++j) {
          if (nb::isinstance<nb::dict>(subtrees[j])) {
            auto subdict = nb::cast<nb::dict>(subtrees[j]);
            if (!subdict.contains(entry.first)) {
              throw std::invalid_argument(
                  "[tree_map] Tree is not a valid prefix tree of the first tree.");
            }
            children[j] = subdict[entry.first];
          } else {
            children[j] = subtrees[j];
          }
        }
        mapped[entry.first] = recurse(children);
      }
      return nb::cast<nb::object>(mapped);
    }

    // Anything else in the first tree is a leaf.
    return transform(subtrees);
  };

  return recurse(trees);
}
// Single-tree convenience overload of tree_map.
//
// Wraps `tree` in a one-element vector and forwards to the multi-tree
// overload, calling `transform` with the only entry of each leaf group.
// Returns whatever the multi-tree overload returns.
nb::object tree_map(
    nb::object tree,
    std::function<nb::object(nb::handle)> transform) {
  // Take the leaf group by const reference: taking it by value would copy the
  // vector (and touch the reference count of its element) for every leaf.
  return tree_map({tree}, [&](const std::vector<nb::object>& inputs) {
    return transform(inputs[0]);
  });
}
// Walk one or more trees in lockstep and call `visitor` on every group of
// leaves.
//
// The first tree drives the traversal: lists, tuples and dicts found in it are
// descended into, and any other object is a leaf. At every level each
// additional tree must either be the same kind of container (same size, and
// for dicts holding every key of the first one) or a non-container object, in
// which case that object is handed unchanged to every child. `visitor`
// receives one entry per input tree.
//
// Throws std::invalid_argument when an additional tree does not match the
// first tree.
void tree_visit(
    const std::vector<nb::object>& trees,
    std::function<void(const std::vector<nb::object>&)> visitor) {
  using Nodes = std::vector<nb::object>;
  std::function<void(const Nodes&)> recurse;

  recurse = [&](const Nodes& subtrees) {
    const size_t num_trees = subtrees.size();

    // List: visit element by element.
    if (nb::isinstance<nb::list>(subtrees[0])) {
      Nodes children(num_trees);
      validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
      const auto head = nb::cast<nb::list>(subtrees[0]);
      // The list length is queried again on every iteration.
      for (size_t i = 0; i < head.size(); ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          if (nb::isinstance<nb::list>(subtrees[j])) {
            children[j] = nb::cast<nb::list>(subtrees[j])[i];
          } else {
            children[j] = subtrees[j];
          }
        }
        recurse(children);
      }
      return;
    }

    // Tuple: visit element by element, with the length read once up front.
    if (nb::isinstance<nb::tuple>(subtrees[0])) {
      Nodes children(num_trees);
      const size_t len = nb::cast<nb::tuple>(subtrees[0]).size();
      validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
      for (size_t i = 0; i < len; ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          if (nb::isinstance<nb::tuple>(subtrees[j])) {
            children[j] = nb::cast<nb::tuple>(subtrees[j])[i];
          } else {
            children[j] = subtrees[j];
          }
        }
        recurse(children);
      }
      return;
    }

    // Dict: visit value by value, following the keys of the first tree.
    if (nb::isinstance<nb::dict>(subtrees[0])) {
      Nodes children(num_trees);
      validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
      for (auto entry : nb::cast<nb::dict>(subtrees[0])) {
        for (size_t j = 0; j < num_trees; ++j) {
          if (nb::isinstance<nb::dict>(subtrees[j])) {
            auto subdict = nb::cast<nb::dict>(subtrees[j]);
            if (!subdict.contains(entry.first)) {
              throw std::invalid_argument(
                  "[tree_visit] Tree is not a valid prefix tree of the first tree.");
            }
            children[j] = subdict[entry.first];
          } else {
            children[j] = subtrees[j];
          }
        }
        recurse(children);
      }
      return;
    }

    // Anything else in the first tree is a leaf.
    visitor(subtrees);
  };

  recurse(trees);
}
// Depth-first walk over a single tree, calling `visitor` on every leaf.
//
// Lists and tuples are iterated element by element, dicts are iterated over
// their values (keys are not visited), and every other object is a leaf.
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
  std::function<void(nb::handle)> recurse;

  recurse = [&](nb::handle subtree) {
    const bool is_sequence = nb::isinstance<nb::list>(subtree) ||
        nb::isinstance<nb::tuple>(subtree);
    if (is_sequence) {
      for (auto child : subtree) {
        recurse(child);
      }
      return;
    }

    if (nb::isinstance<nb::dict>(subtree)) {
      for (auto entry : nb::cast<nb::dict>(subtree)) {
        recurse(entry.second);
      }
      return;
    }

    visitor(subtree);
  };

  recurse(tree);
}
// Walk `tree` and replace every mx::array leaf with the result of `visitor`.
//
// Lists and dicts are updated in place. A tuple is copied into a list,
// updated, and converted to a new tuple, which is returned to the enclosing
// container. Leaves that are not mx::array are returned unchanged and
// `visitor` is not called on them.
//
// The value produced for the root is discarded, so only updates made inside
// list and dict containers are visible to the caller.
void tree_visit_update(
    nb::object tree,
    std::function<nb::object(nb::handle)> visitor) {
  std::function<nb::object(nb::handle)> recurse;

  recurse = [&](nb::handle subtree) -> nb::object {
    if (nb::isinstance<nb::list>(subtree)) {
      auto items = nb::cast<nb::list>(subtree);
      for (size_t i = 0; i < items.size(); ++i) {
        items[i] = recurse(items[i]);
      }
      return nb::cast<nb::object>(items);
    }

    if (nb::isinstance<nb::tuple>(subtree)) {
      // Work on a list copy of the elements, then build a new tuple from it.
      nb::list items(subtree);
      for (size_t i = 0; i < items.size(); ++i) {
        items[i] = recurse(items[i]);
      }
      return nb::cast<nb::object>(nb::tuple(items));
    }

    if (nb::isinstance<nb::dict>(subtree)) {
      auto mapping = nb::cast<nb::dict>(subtree);
      // Only existing keys are reassigned while iterating.
      for (auto entry : mapping) {
        mapping[entry.first] = recurse(entry.second);
      }
      return nb::cast<nb::object>(mapping);
    }

    if (nb::isinstance<mx::array>(subtree)) {
      return visitor(subtree);
    }

    return nb::cast<nb::object>(subtree);
  };

  recurse(tree);
}
// Overwrite the arrays of a pytree with the entries of `values`.
//
// The tree is walked with tree_visit_update, and each array encountered is
// replaced by the next entry of `values`, in traversal order. The visited node
// itself is not inspected. `values` is indexed without a bounds check, so it
// must hold at least one entry per array in the tree.
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
  size_t next = 0;
  auto take_next = [&](nb::handle /* node */) {
    return nb::cast(values[next++]);
  };
  tree_visit_update(tree, take_next);
}
// Swap arrays in `tree`: every array whose id matches an entry of `src` is
// replaced by the entry of `dst` at the same position.
//
// Arrays are matched by id(). When an id occurs more than once in `src`, the
// first occurrence decides the replacement. Arrays with no match are passed
// back through nb::cast. `dst` is indexed with the positions of `src` without
// a bounds check.
void tree_replace(
    nb::object& tree,
    const std::vector<mx::array>& src,
    const std::vector<mx::array>& dst) {
  // Lookup table from source array id to its replacement.
  std::unordered_map<uintptr_t, mx::array> replacements;
  for (size_t i = 0; i < src.size(); ++i) {
    replacements.insert({src[i].id(), dst[i]});
  }

  auto swap_array = [&](nb::handle node) {
    auto current = nb::cast<mx::array>(node);
    auto match = replacements.find(current.id());
    if (match == replacements.end()) {
      return nb::cast(current);
    }
    return nb::cast(match->second);
  };
  tree_visit_update(tree, swap_array);
}
// Collect the mx::array leaves of `tree` into a flat vector, in traversal
// order.
//
// With `strict` set, a leaf that is not an mx::array raises
// std::invalid_argument; otherwise such leaves are skipped.
std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {
  std::vector<mx::array> leaves;

  auto collect = [&](nb::handle obj) {
    if (!nb::isinstance<mx::array>(obj)) {
      if (strict) {
        throw std::invalid_argument(
            "[tree_flatten] The argument should contain only arrays");
      }
      return;
    }
    leaves.push_back(nb::cast<mx::array>(obj));
  };
  tree_visit(tree, collect);

  return leaves;
}
// Build a copy of `tree` in which every mx::array leaf is replaced by
// consecutive entries of `values`, starting at position `index`.
//
// Leaves that are not mx::array are kept as they are. `values` is indexed
// without a bounds check.
nb::object tree_unflatten(
    nb::object tree,
    const std::vector<mx::array>& values,
    int index /* = 0 */) {
  auto substitute = [&](nb::handle obj) -> nb::object {
    if (!nb::isinstance<mx::array>(obj)) {
      return nb::cast<nb::object>(obj);
    }
    return nb::cast(values[index++]);
  };
  return tree_map(tree, substitute);
}
// Return the marker object that stands in for an array inside a tree
// structure.
//
// The marker is created on first use as a capsule holding the address of the
// function-local static that stores it; later calls return the same object.
nb::object structure_sentinel() {
  static nb::object sentinel;

  if (sentinel.ptr() != nullptr) {
    return sentinel;
  }

  sentinel = nb::capsule(&sentinel);
  // Probably not needed, but the extra reference should make certain that the
  // sentinel is never deleted.
  sentinel.inc_ref();
  return sentinel;
}
// Flatten `tree` into its mx::array leaves and a structure tree describing
// where they came from.
//
// Returns a pair of:
//   - the arrays, in traversal order;
//   - a copy of the tree in which every array is replaced by the object
//     returned by structure_sentinel().
//
// With `strict` set, a leaf that is not an mx::array raises
// std::invalid_argument; otherwise such leaves are kept in the structure.
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
    nb::object tree,
    bool strict /* = true */) {
  auto sentinel = structure_sentinel();
  std::vector<mx::array> flat_tree;

  auto structure = tree_map(
      tree,
      [&flat_tree, sentinel = std::move(sentinel), strict](
          nb::handle obj) -> nb::object {
        if (nb::isinstance<mx::array>(obj)) {
          flat_tree.push_back(nb::cast<mx::array>(obj));
          return sentinel;
        }
        if (!strict) {
          return nb::cast<nb::object>(obj);
        }
        throw std::invalid_argument(
            "[tree_flatten] The argument should contain only arrays");
      });

  // Both locals are dead after this point: move them into the result rather
  // than copying the whole vector of arrays and the structure object.
  return {std::move(flat_tree), std::move(structure)};
}
// Rebuild a tree from `structure` by replacing every occurrence of the object
// returned by structure_sentinel() with consecutive entries of `values`,
// starting at position `index`.
//
// Other leaves are kept as they are. `values` is indexed without a bounds
// check.
nb::object tree_unflatten_from_structure(
    nb::object structure,
    const std::vector<mx::array>& values,
    int index /* = 0 */) {
  auto sentinel = structure_sentinel();

  auto restore = [&](nb::handle obj) -> nb::object {
    if (!obj.is(sentinel)) {
      return nb::cast<nb::object>(obj);
    }
    return nb::cast(values[index++]);
  };
  return tree_map(structure, restore);
}