Make sure that arrays are freed when saving (#247)

This commit is contained in:
Angelos Katharopoulos 2023-12-21 14:08:24 -08:00 committed by GitHub
parent b3916cbf2b
commit 2c7df6795e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 8 deletions

View File

@ -286,7 +286,11 @@ class PyFileWriter : public io::Writer {
py::object tell_func_;
};
void mlx_save_helper(py::object file, array a, bool retain_graph) {
void mlx_save_helper(
py::object file,
array a,
std::optional<bool> retain_graph_) {
bool retain_graph = retain_graph_.value_or(a.is_tracer());
if (py::isinstance<py::str>(file)) {
save(py::cast<std::string>(file), a, retain_graph);
return;
@ -351,7 +355,7 @@ void mlx_savez_helper(
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
py::gil_scoped_release gil;
save(writer, a);
save(writer, a, /*retain_graph=*/a.is_tracer());
}
}

View File

@ -13,9 +13,12 @@ using namespace mlx::core;
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
void mlx_save_helper(py::object file, array a, bool retain_graph = true);
void mlx_save_helper(
py::object file,
array a,
std::optional<bool> retain_graph = std::nullopt);
void mlx_savez_helper(
py::object file,
py::args args,
const py::kwargs& kwargs,
bool compressed = false);
bool compressed = false);

View File

@ -2857,18 +2857,20 @@ void init_ops(py::module_& m) {
"file"_a,
"arr"_a,
py::pos_only(),
"retain_graph"_a = true,
"retain_graph"_a = std::nullopt,
py::kw_only(),
R"pbdoc(
save(file: str, arr: array, / , retain_graph: bool = True)
save(file: str, arr: array, / , retain_graph: Optional[bool] = None)
Save the array to a binary file in ``.npy`` format.
Args:
file (str): File to which the array is saved
arr (array): Array to be saved.
retain_graph(bool): Optional argument to retain graph
during array evaluation before saving. Default: True
retain_graph (bool, optional): Optional argument to retain graph
during array evaluation before saving. If not provided the graph
is retained if we are during a function transformation. Default:
None
)pbdoc");
m.def(