Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -2902,20 +2902,14 @@ void init_ops(py::module_& m) {
&mlx_save_helper,
"file"_a,
"arr"_a,
py::pos_only(),
"retain_graph"_a = std::nullopt,
py::kw_only(),
R"pbdoc(
save(file: str, arr: array, / , retain_graph: Optional[bool] = None)
save(file: str, arr: array)
Save the array to a binary file in ``.npy`` format.
Args:
file (str): File to which the array is saved
arr (array): Array to be saved.
retain_graph (bool, optional): Whether or not to retain the graph
during array evaluation. If left unspecified the graph is retained
only if saving is done in a function transformation. Default: ``None``
)pbdoc");
m.def(
"savez",
@@ -2999,11 +2993,8 @@ void init_ops(py::module_& m) {
&mlx_save_safetensor_helper,
"file"_a,
"arrays"_a,
py::pos_only(),
"retain_graph"_a = std::nullopt,
py::kw_only(),
R"pbdoc(
save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
save_safetensors(file: str, arrays: Dict[str, array])
Save array(s) to a binary file in ``.safetensors`` format.
@@ -3012,9 +3003,6 @@ void init_ops(py::module_& m) {
Args:
file (file, str): File in which the array is saved>
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
retain_graph (bool, optional): Whether or not to retain the graph
during array evaluation. If left unspecified the graph is retained
only if saving is done in a function transformation. Default: ``None``.
)pbdoc");
m.def(
"where",