mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
Make sure that arrays are freed when saving (#247)
This commit is contained in:
parent
b3916cbf2b
commit
2c7df6795e
@ -286,7 +286,11 @@ class PyFileWriter : public io::Writer {
|
|||||||
py::object tell_func_;
|
py::object tell_func_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void mlx_save_helper(py::object file, array a, bool retain_graph) {
|
void mlx_save_helper(
|
||||||
|
py::object file,
|
||||||
|
array a,
|
||||||
|
std::optional<bool> retain_graph_) {
|
||||||
|
bool retain_graph = retain_graph_.value_or(a.is_tracer());
|
||||||
if (py::isinstance<py::str>(file)) {
|
if (py::isinstance<py::str>(file)) {
|
||||||
save(py::cast<std::string>(file), a, retain_graph);
|
save(py::cast<std::string>(file), a, retain_graph);
|
||||||
return;
|
return;
|
||||||
@ -351,7 +355,7 @@ void mlx_savez_helper(
|
|||||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release gil;
|
||||||
save(writer, a);
|
save(writer, a, /*retain_graph=*/a.is_tracer());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,9 +13,12 @@ using namespace mlx::core;
|
|||||||
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
||||||
|
|
||||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
|
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
|
||||||
void mlx_save_helper(py::object file, array a, bool retain_graph = true);
|
void mlx_save_helper(
|
||||||
|
py::object file,
|
||||||
|
array a,
|
||||||
|
std::optional<bool> retain_graph = std::nullopt);
|
||||||
void mlx_savez_helper(
|
void mlx_savez_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
py::args args,
|
py::args args,
|
||||||
const py::kwargs& kwargs,
|
const py::kwargs& kwargs,
|
||||||
bool compressed = false);
|
bool compressed = false);
|
||||||
|
@ -2857,18 +2857,20 @@ void init_ops(py::module_& m) {
|
|||||||
"file"_a,
|
"file"_a,
|
||||||
"arr"_a,
|
"arr"_a,
|
||||||
py::pos_only(),
|
py::pos_only(),
|
||||||
"retain_graph"_a = true,
|
"retain_graph"_a = std::nullopt,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
save(file: str, arr: array, / , retain_graph: bool = True)
|
save(file: str, arr: array, / , retain_graph: Optional[bool] = None)
|
||||||
|
|
||||||
Save the array to a binary file in ``.npy`` format.
|
Save the array to a binary file in ``.npy`` format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (str): File to which the array is saved
|
file (str): File to which the array is saved
|
||||||
arr (array): Array to be saved.
|
arr (array): Array to be saved.
|
||||||
retain_graph(bool): Optional argument to retain graph
|
retain_graph (bool, optional): Optional argument to retain graph
|
||||||
during array evaluation before saving. Default: True
|
during array evaluation before saving. If not provided the graph
|
||||||
|
is retained if we are during a function transformation. Default:
|
||||||
|
None
|
||||||
|
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
|
Loading…
Reference in New Issue
Block a user