Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -440,11 +440,10 @@ auto py_vmap(
void init_transforms(py::module_& m) {
m.def(
"eval",
[](const py::args& args, bool retain_graph) {
[](const py::args& args) {
std::vector<array> arrays = tree_flatten(args);
eval(arrays, retain_graph);
eval(arrays);
},
"retain_graph"_a = false,
R"pbdoc(
Evaluate an :class:`array` or tree of :class:`array`.
@@ -453,9 +452,6 @@ void init_transforms(py::module_& m) {
or a tree of arrays. If a tree is given the nodes can be a Python
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
an :class:`array`.
retain_graph (bool): Indicate that the graph structure should be
preserved. This option is intended to enable function transforms
which contain control flow based on the value of an array.
)pbdoc");
m.def(
"jvp",