diff --git a/python/src/load.cpp b/python/src/load.cpp index 0730d33d6..1a52930b2 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -286,7 +286,11 @@ class PyFileWriter : public io::Writer { py::object tell_func_; }; -void mlx_save_helper(py::object file, array a, bool retain_graph) { +void mlx_save_helper( + py::object file, + array a, + std::optional retain_graph_) { + bool retain_graph = retain_graph_.value_or(a.is_tracer()); if (py::isinstance(file)) { save(py::cast(file), a, retain_graph); return; @@ -351,7 +355,7 @@ void mlx_savez_helper( auto writer = std::make_shared(py_ostream); { py::gil_scoped_release gil; - save(writer, a); + save(writer, a, /*retain_graph=*/a.is_tracer()); } } diff --git a/python/src/load.h b/python/src/load.h index bbabef2c9..8f64a64d1 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -13,9 +13,12 @@ using namespace mlx::core; using DictOrArray = std::variant>; DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); -void mlx_save_helper(py::object file, array a, bool retain_graph = true); +void mlx_save_helper( + py::object file, + array a, + std::optional retain_graph = std::nullopt); void mlx_savez_helper( py::object file, py::args args, const py::kwargs& kwargs, - bool compressed = false); \ No newline at end of file + bool compressed = false); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f8e4f237a..23a6ec2c6 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2857,18 +2857,20 @@ void init_ops(py::module_& m) { "file"_a, "arr"_a, py::pos_only(), - "retain_graph"_a = true, + "retain_graph"_a = std::nullopt, py::kw_only(), R"pbdoc( - save(file: str, arr: array, / , retain_graph: bool = True) + save(file: str, arr: array, / , retain_graph: Optional[bool] = None) Save the array to a binary file in ``.npy`` format. Args: file (str): File to which the array is saved arr (array): Array to be saved. - retain_graph(bool): Optional argument to retain graph - during array evaluation before saving. Default: True + retain_graph (bool, optional): Optional argument to retain graph + during array evaluation before saving. If not provided the graph + is retained if we are during a function transformation. Default: + None )pbdoc"); m.def(