mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 12:09:43 +08:00
Removes the retain_graph flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:
committed by
GitHub
parent
449b43762e
commit
a611b0bc82
@@ -345,19 +345,15 @@ class PyFileWriter : public io::Writer {
|
||||
py::object tell_func_;
|
||||
};
|
||||
|
||||
void mlx_save_helper(
|
||||
py::object file,
|
||||
array a,
|
||||
std::optional<bool> retain_graph_) {
|
||||
bool retain_graph = retain_graph_.value_or(a.is_tracer());
|
||||
void mlx_save_helper(py::object file, array a) {
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save(py::cast<std::string>(file), a, retain_graph);
|
||||
save(py::cast<std::string>(file), a);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save(writer, a, retain_graph);
|
||||
save(writer, a);
|
||||
}
|
||||
|
||||
return;
|
||||
@@ -414,26 +410,23 @@ void mlx_savez_helper(
|
||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save(writer, a, /*retain_graph=*/a.is_tracer());
|
||||
save(writer, a);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<bool> retain_graph) {
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph);
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save_safetensors(writer, arrays_map, retain_graph);
|
||||
save_safetensors(writer, arrays_map);
|
||||
}
|
||||
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user