mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 20:20:11 +08:00
Make sure that arrays are freed when saving (#247)
This commit is contained in:
committed by
GitHub
parent
b3916cbf2b
commit
2c7df6795e
@@ -286,7 +286,11 @@ class PyFileWriter : public io::Writer {
|
||||
py::object tell_func_;
|
||||
};
|
||||
|
||||
void mlx_save_helper(py::object file, array a, bool retain_graph) {
|
||||
void mlx_save_helper(
|
||||
py::object file,
|
||||
array a,
|
||||
std::optional<bool> retain_graph_) {
|
||||
bool retain_graph = retain_graph_.value_or(a.is_tracer());
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save(py::cast<std::string>(file), a, retain_graph);
|
||||
return;
|
||||
@@ -351,7 +355,7 @@ void mlx_savez_helper(
|
||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save(writer, a);
|
||||
save(writer, a, /*retain_graph=*/a.is_tracer());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user