Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -345,19 +345,15 @@ class PyFileWriter : public io::Writer {
py::object tell_func_;
};
void mlx_save_helper(
py::object file,
array a,
std::optional<bool> retain_graph_) {
bool retain_graph = retain_graph_.value_or(a.is_tracer());
void mlx_save_helper(py::object file, array a) {
if (py::isinstance<py::str>(file)) {
save(py::cast<std::string>(file), a, retain_graph);
save(py::cast<std::string>(file), a);
return;
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release gil;
save(writer, a, retain_graph);
save(writer, a);
}
return;
@@ -414,26 +410,23 @@ void mlx_savez_helper(
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
py::gil_scoped_release gil;
save(writer, a, /*retain_graph=*/a.is_tracer());
save(writer, a);
}
}
return;
}
void mlx_save_safetensor_helper(
py::object file,
py::dict d,
std::optional<bool> retain_graph) {
void mlx_save_safetensor_helper(py::object file, py::dict d) {
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) {
save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph);
save_safetensors(py::cast<std::string>(file), arrays_map);
return;
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release gil;
save_safetensors(writer, arrays_map, retain_graph);
save_safetensors(writer, arrays_map);
}
return;