Limit grad recursion depth by not recursing through non-grad inputs (#1764)

* limit grad recursion depth

* add grad of module test
This commit is contained in:
Awni Hannun
2025-01-14 14:33:18 -08:00
committed by GitHub
parent 5cc5201914
commit 33421c1dd3
6 changed files with 136 additions and 100 deletions

View File

@@ -146,7 +146,7 @@ void tree_visit(
return recurse(trees);
}
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
std::function<void(nb::handle)> recurse;
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree) ||
@@ -178,10 +178,11 @@ void tree_visit_update(
}
return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
nb::list l(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return nb::cast<nb::object>(subtree);
return nb::cast<nb::object>(nb::tuple(l));
} else if (nb::isinstance<nb::dict>(subtree)) {
auto d = nb::cast<nb::dict>(subtree);
for (auto item : d) {
@@ -224,7 +225,7 @@ void tree_replace(
});
}
std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {
std::vector<mx::array> flat_tree;
tree_visit(tree, [&](nb::handle obj) {