From c88d3174aae239fd4c404ae8916dd790d642cfa1 Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Mon, 18 Dec 2023 20:25:51 -0500 Subject: [PATCH] added save bindings and fixed header --- mlx/ops.h | 2 +- python/src/load.cpp | 19 +++++++++++++++++++ python/src/load.h | 2 ++ python/src/ops.cpp | 16 ++++++++++++++++ 4 files changed, 38 insertions(+), 1 deletion(-) diff --git a/mlx/ops.h b/mlx/ops.h index bd002b5df..d8bc05c72 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1066,7 +1066,7 @@ std::unordered_map load_safetensor( StreamOrDevice s = {}); void save_safetensor( - std::shared_ptr in_stream, + std::shared_ptr in_stream, std::unordered_map); void save_safetensor( const std::string& file, diff --git a/python/src/load.cpp b/python/src/load.cpp index 82b018189..f859febe3 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -382,3 +382,22 @@ void mlx_savez_helper( return; } + +void mlx_save_safetensor_helper(py::object file, py::dict d) { + auto arrays_map = d.cast>(); + if (py::isinstance(file)) { + save_safetensor(py::cast(file), arrays_map); + return; + } else if (is_ostream_object(file)) { + auto writer = std::make_shared(file); + { + py::gil_scoped_release gil; + save_safetensor(writer, arrays_map); + } + + return; + } + + throw std::invalid_argument( + "[save_safetensor] Input must be a file-like object, string, or pathlib.Path"); +} \ No newline at end of file diff --git a/python/src/load.h b/python/src/load.h index 20f0c79dd..1ced6e35d 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -15,6 +15,8 @@ using DictOrArray = std::variant>; std::unordered_map mlx_load_safetensor_helper( py::object file, StreamOrDevice s); +void mlx_save_safetensor_helper(py::object file, py::dict d); + DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); void mlx_save_helper( py::object file, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a49b82317..2189a0a10 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2963,6 +2963,22 @@ void init_ops(py::module_& m) { Returns: result dict: The loaded dict mapping name to array from the ``.safetensors`` file )pbdoc"); + m.def( + "save_safetensor", + &mlx_save_safetensor_helper, + "file"_a, + "d"_a, + py::pos_only(), + py::kw_only(), + R"pbdoc( + save_safetensor(file: str, d: Dict[str, array], /, *, stream: Union[None, Stream, Device] = None) + + Save array(s) to a binary file in ``.safetensors`` format. + + Args: + file (file, str): File in which the array is saved + d (Dict[str, array]): The dict mapping name to array to be saved + )pbdoc"); m.def( "where", [](const ScalarOrArray& condition,