Make sure that arrays are freed when saving (#247)

This commit is contained in:
Angelos Katharopoulos
2023-12-21 14:08:24 -08:00
committed by GitHub
parent b3916cbf2b
commit 2c7df6795e
3 changed files with 17 additions and 8 deletions

View File

@@ -286,7 +286,11 @@ class PyFileWriter : public io::Writer {
py::object tell_func_;
};
void mlx_save_helper(py::object file, array a, bool retain_graph) {
void mlx_save_helper(
py::object file,
array a,
std::optional<bool> retain_graph_) {
bool retain_graph = retain_graph_.value_or(a.is_tracer());
if (py::isinstance<py::str>(file)) {
save(py::cast<std::string>(file), a, retain_graph);
return;
@@ -351,7 +355,7 @@ void mlx_savez_helper(
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
py::gil_scoped_release gil;
save(writer, a);
save(writer, a, /*retain_graph=*/a.is_tracer());
}
}