Limit grad recursion depth by not recursing through non-grad inputs (#1764)

* limit grad recursion depth

* add grad of module test
This commit is contained in:
Awni Hannun
2025-01-14 14:33:18 -08:00
committed by GitHub
parent 5cc5201914
commit 33421c1dd3
6 changed files with 136 additions and 100 deletions

View File

@@ -10,7 +10,7 @@ namespace nb = nanobind;
void tree_visit(
const std::vector<nb::object>& trees,
std::function<void(const std::vector<nb::object>&)> visitor);
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor);
nb::object tree_map(
const std::vector<nb::object>& trees,
@@ -42,7 +42,7 @@ void tree_replace(
* Flatten a tree into a vector of arrays. If strict is true, then the
* function will throw if the tree contains a leaf which is not an array.
*/
std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true);
std::vector<mx::array> tree_flatten(nb::handle tree, bool strict = true);
/**
* Unflatten a tree from a vector of arrays.