mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-08 13:28:15 +08:00
Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth * add grad of module test
This commit is contained in:
@@ -10,7 +10,7 @@ namespace nb = nanobind;
|
||||
void tree_visit(
|
||||
const std::vector<nb::object>& trees,
|
||||
std::function<void(const std::vector<nb::object>&)> visitor);
|
||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
|
||||
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor);
|
||||
|
||||
nb::object tree_map(
|
||||
const std::vector<nb::object>& trees,
|
||||
@@ -42,7 +42,7 @@ void tree_replace(
|
||||
* Flatten a tree into a vector of arrays. If strict is true, then the
|
||||
* function will throw if the tree contains a leaf which is not an array.
|
||||
*/
|
||||
std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true);
|
||||
std::vector<mx::array> tree_flatten(nb::handle tree, bool strict = true);
|
||||
|
||||
/**
|
||||
* Unflatten a tree from a vector of arrays.
|
||||
|
||||
Reference in New Issue
Block a user