mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth * add grad of module test
This commit is contained in:
@@ -146,7 +146,7 @@ void tree_visit(
|
||||
return recurse(trees);
|
||||
}
|
||||
|
||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
|
||||
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
|
||||
std::function<void(nb::handle)> recurse;
|
||||
recurse = [&](nb::handle subtree) {
|
||||
if (nb::isinstance<nb::list>(subtree) ||
|
||||
@@ -178,10 +178,11 @@ void tree_visit_update(
|
||||
}
|
||||
return nb::cast<nb::object>(l);
|
||||
} else if (nb::isinstance<nb::tuple>(subtree)) {
|
||||
for (auto item : subtree) {
|
||||
recurse(item);
|
||||
nb::list l(subtree);
|
||||
for (int i = 0; i < l.size(); ++i) {
|
||||
l[i] = recurse(l[i]);
|
||||
}
|
||||
return nb::cast<nb::object>(subtree);
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
} else if (nb::isinstance<nb::dict>(subtree)) {
|
||||
auto d = nb::cast<nb::dict>(subtree);
|
||||
for (auto item : d) {
|
||||
@@ -224,7 +225,7 @@ void tree_replace(
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
||||
std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {
|
||||
std::vector<mx::array> flat_tree;
|
||||
|
||||
tree_visit(tree, [&](nb::handle obj) {
|
||||
|
Reference in New Issue
Block a user