From ee6ce00aee2bee286118a7278b8ecab6ec61f545 Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Fri, 22 Dec 2023 16:19:31 -0500 Subject: [PATCH] docs and made retain_graph optional bool --- mlx/io/safetensor.cpp | 6 +++--- mlx/ops.h | 4 ++-- python/src/load.cpp | 2 +- python/src/ops.cpp | 18 ++++++++++-------- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index 10092df6b..14191fde3 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -127,7 +127,7 @@ std::unordered_map load_safetensor( void save_safetensor( std::shared_ptr out_stream, std::unordered_map a, - bool retain_graph) { + std::optional retain_graph_) { //////////////////////////////////////////////////////// // Check file if (!out_stream->good() || !out_stream->is_open()) { @@ -143,7 +143,7 @@ void save_safetensor( }); size_t offset = 0; for (auto& [key, arr] : a) { - arr.eval(retain_graph); + arr.eval(retain_graph_.value_or(arr.is_tracer())); if (arr.nbytes() == 0) { throw std::invalid_argument( "[save_safetensor] cannot serialize an empty array key: " + key); @@ -174,7 +174,7 @@ void save_safetensor( void save_safetensor( const std::string& file_, std::unordered_map a, - bool retain_graph) { + std::optional retain_graph) { // Open and check file std::string file = file_; diff --git a/mlx/ops.h b/mlx/ops.h index b228d4114..fa5163474 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1068,9 +1068,9 @@ std::unordered_map load_safetensor( void save_safetensor( std::shared_ptr in_stream, std::unordered_map, - bool retain_graph = true); + std::optional retain_graph = std::nullopt); void save_safetensor( const std::string& file, std::unordered_map, - bool retain_graph = true); + std::optional retain_graph = std::nullopt); } // namespace mlx::core diff --git a/python/src/load.cpp b/python/src/load.cpp index 78bdbac35..a2e605811 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -424,7 +424,7 @@ void mlx_savez_helper( void mlx_save_safetensor_helper( py::object file, py::dict d, - bool retain_graph) { + std::optional retain_graph) { auto arrays_map = d.cast>(); if (py::isinstance(file)) { save_safetensor(py::cast(file), arrays_map, retain_graph); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 23fc3b611..39148dd57 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2938,14 +2938,14 @@ void init_ops(py::module_& m) { R"pbdoc( load(file: str, format: Optional[str] = None, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] - Load array(s) from a binary file in ``.npy`` or ``.npz`` format. + Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format. Args: file (file, str): File in which the array is saved format (str, optional): Format of the file. If ``None``, the format - is inferred from the file extension. Supported formats: npy, npz, and safetensors. (default: ``None``) + is inferred from the file extension. Supported formats: ``npy``, ``npz``, and ``safetensors``. (default: ``None``) Returns: - result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file + result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` or ``.safetensors`` file )pbdoc"); m.def( "save_safetensor", @@ -2953,18 +2953,20 @@ void init_ops(py::module_& m) { "file"_a, "d"_a, py::pos_only(), - "retain_graph"_a = true, + "retain_graph"_a = std::nullopt, py::kw_only(), R"pbdoc( - save_safetensor(file: str, d: Dict[str, array], /, retain_graph: bool = True, *) + save_safetensor(file: str, d: Dict[str, array], /, retain_graph: Optional[bool] = None, *) - Save array(s) to a binary file in ``.safetensors`` format. + Save array(s) to a binary file in ``.safetensors`` format. + For more information on the format see https://huggingface.co/docs/safetensors/index. Args: file (file, str): File in which the array is saved d (Dict[str, array]): The dict mapping name to array to be saved - retain_graph(bool): Optional argument to retain graph - during array evaluation before saving. Default: True + retain_graph(Optional[bool]): Optional argument to retain graph + during array evaluation before saving. If not provided the graph + is retained if we are during a function transformation. Default: None )pbdoc"); m.def( "where",