Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -17,19 +17,13 @@ using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
py::object file,
StreamOrDevice s);
void mlx_save_safetensor_helper(
py::object file,
py::dict d,
std::optional<bool> retain_graph = std::nullopt);
void mlx_save_safetensor_helper(py::object file, py::dict d);
DictOrArray mlx_load_helper(
py::object file,
std::optional<std::string> format,
StreamOrDevice s);
void mlx_save_helper(
py::object file,
array a,
std::optional<bool> retain_graph = std::nullopt);
void mlx_save_helper(py::object file, array a);
void mlx_savez_helper(
py::object file,
py::args args,