From fdf9d99f0f5695432d6e6aae6a72590c2c44e25f Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Wed, 20 Dec 2023 14:39:31 -0500 Subject: [PATCH] add back retain_graph argument --- mlx/ops.h | 6 ++++-- mlx/safetensor.cpp | 10 ++++++---- python/src/load.cpp | 9 ++++++--- python/src/load.h | 5 ++++- python/src/ops.cpp | 5 ++++- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/mlx/ops.h b/mlx/ops.h index c0feaa402..f8737d83c 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1067,8 +1067,10 @@ std::unordered_map load_safetensor( void save_safetensor( std::shared_ptr in_stream, - std::unordered_map); + std::unordered_map, + bool retain_graph = true); void save_safetensor( const std::string& file, - std::unordered_map); + std::unordered_map, + bool retain_graph = true); } // namespace mlx::core diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index f511faf80..b198454e0 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -126,7 +126,8 @@ std::unordered_map load_safetensor( /** Save array to out stream in .npy format */ void save_safetensor( std::shared_ptr out_stream, - std::unordered_map a) { + std::unordered_map a, + bool retain_graph) { //////////////////////////////////////////////////////// // Check file if (!out_stream->good() || !out_stream->is_open()) { @@ -142,7 +143,7 @@ void save_safetensor( }); size_t offset = 0; for (auto& [key, arr] : a) { - arr.eval(false); + arr.eval(retain_graph); if (arr.nbytes() == 0) { throw std::invalid_argument( "[save_safetensor] cannot serialize an empty array key: " + key); @@ -172,7 +173,8 @@ void save_safetensor( void save_safetensor( const std::string& file_, - std::unordered_map a) { + std::unordered_map a, + bool retain_graph) { // Open and check file std::string file = file_; @@ -182,7 +184,7 @@ void save_safetensor( file += ".safetensors"; // Serialize array - save_safetensor(std::make_shared(file), a); + save_safetensor(std::make_shared(file), a, retain_graph); } } // namespace mlx::core \ No newline at end of file diff --git a/python/src/load.cpp b/python/src/load.cpp index 1246a94b2..492a40b41 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -419,16 +419,19 @@ void mlx_savez_helper( return; } -void mlx_save_safetensor_helper(py::object file, py::dict d) { +void mlx_save_safetensor_helper( + py::object file, + py::dict d, + bool retain_graph) { auto arrays_map = d.cast>(); if (py::isinstance(file)) { - save_safetensor(py::cast(file), arrays_map); + save_safetensor(py::cast(file), arrays_map, retain_graph); return; } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { py::gil_scoped_release gil; - save_safetensor(writer, arrays_map); + save_safetensor(writer, arrays_map, retain_graph); } return; diff --git a/python/src/load.h b/python/src/load.h index 45d64664d..9103678bf 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -17,7 +17,10 @@ using DictOrArray = std::variant>; std::unordered_map mlx_load_safetensor_helper( py::object file, StreamOrDevice s); -void mlx_save_safetensor_helper(py::object file, py::dict d); +void mlx_save_safetensor_helper( + py::object file, + py::dict d, + bool retain_graph = true); DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); void mlx_save_helper( diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 6455cfc22..23fc3b611 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2953,15 +2953,18 @@ void init_ops(py::module_& m) { "file"_a, "d"_a, py::pos_only(), + "retain_graph"_a = true, py::kw_only(), R"pbdoc( - save_safetensor(file: str, d: Dict[str, array], /, *) + save_safetensor(file: str, d: Dict[str, array], /, retain_graph: bool = True, *) Save array(s) to a binary file in ``.safetensors`` format. Args: file (file, str): File in which the array is saved d (Dict[str, array]): The dict mapping name to array to be saved + retain_graph(bool): Optional argument to retain graph + during array evaluation before saving. Default: True )pbdoc"); m.def( "where",