From 313f6bd9b1bb55545b3f9f10baf765c6e2f86da0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 22 Dec 2023 21:06:49 -0800 Subject: [PATCH] change name to safetensors --- docs/src/python/ops.rst | 1 + mlx/io/safetensor.cpp | 26 +++++++++++++------------- mlx/ops.h | 10 +++++----- python/src/load.cpp | 14 +++++++------- python/src/ops.cpp | 36 +++++++++++++++++++----------------- python/tests/test_load.py | 4 ++-- tests/load_tests.cpp | 6 +++--- 7 files changed, 50 insertions(+), 47 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 7e391ec4c..0c5763290 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -83,6 +83,7 @@ Operations save savez savez_compressed + save_safetensors sigmoid sign sin diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index 14191fde3..c17f713e8 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -69,21 +69,21 @@ Dtype dtype_from_safetensor_str(std::string str) { } /** Load array from reader in safetensor format */ -std::unordered_map load_safetensor( +std::unordered_map load_safetensors( std::shared_ptr in_stream, StreamOrDevice s) { //////////////////////////////////////////////////////// // Open and check file if (!in_stream->good() || !in_stream->is_open()) { throw std::runtime_error( - "[load_safetensor] Failed to open " + in_stream->label()); + "[load_safetensors] Failed to open " + in_stream->label()); } uint64_t jsonHeaderLength = 0; in_stream->read(reinterpret_cast(&jsonHeaderLength), 8); if (jsonHeaderLength <= 0) { throw std::runtime_error( - "[load_safetensor] Invalid json header length " + in_stream->label()); + "[load_safetensors] Invalid json header length " + in_stream->label()); } // Load the json metadata char rawJson[jsonHeaderLength]; @@ -92,7 +92,7 @@ std::unordered_map load_safetensor( // Should always be an object on the top-level if (!metadata.is_object()) { throw std::runtime_error( - "[load_safetensor] Invalid json metadata " + 
in_stream->label()); + "[load_safetensors] Invalid json metadata " + in_stream->label()); } size_t offset = jsonHeaderLength + 8; // Load the arrays using metadata @@ -117,14 +117,14 @@ std::unordered_map load_safetensor( return res; } -std::unordered_map load_safetensor( +std::unordered_map load_safetensors( const std::string& file, StreamOrDevice s) { - return load_safetensor(std::make_shared(file), s); + return load_safetensors(std::make_shared(file), s); } /** Save array to out stream in .npy format */ -void save_safetensor( +void save_safetensors( std::shared_ptr out_stream, std::unordered_map a, std::optional retain_graph_) { @@ -132,7 +132,7 @@ void save_safetensor( // Check file if (!out_stream->good() || !out_stream->is_open()) { throw std::runtime_error( - "[save_safetensor] Failed to open " + out_stream->label()); + "[save_safetensors] Failed to open " + out_stream->label()); } //////////////////////////////////////////////////////// @@ -146,12 +146,12 @@ void save_safetensor( arr.eval(retain_graph_.value_or(arr.is_tracer())); if (arr.nbytes() == 0) { throw std::invalid_argument( - "[save_safetensor] cannot serialize an empty array key: " + key); + "[save_safetensors] cannot serialize an empty array key: " + key); } if (!arr.flags().contiguous) { throw std::invalid_argument( - "[save_safetensor] cannot serialize a non-contiguous array key: " + + "[save_safetensors] cannot serialize a non-contiguous array key: " + key); } json child; @@ -171,7 +171,7 @@ void save_safetensor( } } -void save_safetensor( +void save_safetensors( const std::string& file_, std::unordered_map a, std::optional retain_graph) { @@ -184,7 +184,7 @@ void save_safetensor( file += ".safetensors"; // Serialize array - save_safetensor(std::make_shared(file), a, retain_graph); + save_safetensors(std::make_shared(file), a, retain_graph); } -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index fa5163474..e1abac6fb 100644 --- 
a/mlx/ops.h +++ b/mlx/ops.h @@ -1057,19 +1057,19 @@ array dequantize( int bits = 4, StreamOrDevice s = {}); -/** Load array map from .safetensor file format */ -std::unordered_map load_safetensor( +/** Load array map from .safetensors file format */ +std::unordered_map load_safetensors( std::shared_ptr in_stream, StreamOrDevice s = {}); -std::unordered_map load_safetensor( +std::unordered_map load_safetensors( const std::string& file, StreamOrDevice s = {}); -void save_safetensor( +void save_safetensors( std::shared_ptr in_stream, std::unordered_map, std::optional retain_graph = std::nullopt); -void save_safetensor( +void save_safetensors( const std::string& file, std::unordered_map, std::optional retain_graph = std::nullopt); diff --git a/python/src/load.cpp b/python/src/load.cpp index f6609faea..a63e5063e 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -164,10 +164,10 @@ std::unordered_map mlx_load_safetensor_helper( py::object file, StreamOrDevice s) { if (py::isinstance(file)) { // Assume .safetensors file path string - return {load_safetensor(py::cast(file), s)}; + return {load_safetensors(py::cast(file), s)}; } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately - auto arr = load_safetensor(std::make_shared(file), s); + auto arr = load_safetensors(std::make_shared(file), s); { py::gil_scoped_release gil; for (auto& [key, arr] : arr) { @@ -178,7 +178,7 @@ std::unordered_map mlx_load_safetensor_helper( } throw std::invalid_argument( - "[load_safetensor] Input must be a file-like object, or string"); + "[load_safetensors] Input must be a file-like object, or string"); } std::unordered_map mlx_load_npz_helper( @@ -427,18 +427,18 @@ void mlx_save_safetensor_helper( std::optional retain_graph) { auto arrays_map = d.cast>(); if (py::isinstance(file)) { - save_safetensor(py::cast(file), arrays_map, retain_graph); + save_safetensors(py::cast(file), arrays_map, retain_graph); return; } else if 
(is_ostream_object(file)) { auto writer = std::make_shared(file); { py::gil_scoped_release gil; - save_safetensor(writer, arrays_map, retain_graph); + save_safetensors(writer, arrays_map, retain_graph); } return; } throw std::invalid_argument( - "[save_safetensor] Input must be a file-like object, or string"); -} \ No newline at end of file + "[save_safetensors] Input must be a file-like object, or string"); +} diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 39148dd57..db405b127 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2867,11 +2867,9 @@ void init_ops(py::module_& m) { Args: file (str): File to which the array is saved arr (array): Array to be saved. - retain_graph (bool, optional): Optional argument to retain graph - during array evaluation before saving. If not provided the graph - is retained if we are during a function transformation. Default: - None - + retain_graph (bool, optional): Whether or not to retain the graph + during array evaluation. If left unspecified the graph is retained + only if saving is done in a function transformation. Default: ``None`` )pbdoc"); m.def( "savez", @@ -2941,32 +2939,36 @@ void init_ops(py::module_& m) { Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format. Args: - file (file, str): File in which the array is saved + file (file, str): File in which the array is saved. format (str, optional): Format of the file. If ``None``, the format - is inferred from the file extension. Supported formats: ``npy``, ``npz``, and ``safetensors``. (default: ``None``) + is inferred from the file extension. Supported formats: ``npy``, + ``npz``, and ``safetensors``. Default: ``None``. Returns: - result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` or ``.safetensors`` file + result (array, dict): + A single array if loading from a ``.npy`` file or a dict mapping + names to arrays if loading from a ``.npz`` or ``.safetensors`` file. 
)pbdoc"); m.def( - "save_safetensor", + "save_safetensors", &mlx_save_safetensor_helper, "file"_a, - "d"_a, + "arrays"_a, py::pos_only(), "retain_graph"_a = std::nullopt, py::kw_only(), R"pbdoc( - save_safetensor(file: str, d: Dict[str, array], /, retain_graph: Optional[bool] = None, *) + save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None) + + Save array(s) to a binary file in ``.safetensors`` format. - Save array(s) to a binary file in ``.safetensors`` format. For more information on the format see https://huggingface.co/docs/safetensors/index. Args: - file (file, str): File in which the array is saved - d (Dict[str, array]): The dict mapping name to array to be saved - retain_graph(Optional[bool]): Optional argument to retain graph - during array evaluation before saving. If not provided the graph - is retained if we are during a function transformation. Default: None + file (file, str): File in which the array is saved. + arrays (dict(str, array)): The dictionary of names to arrays to be saved. + retain_graph (bool, optional): Whether or not to retain the graph + during array evaluation. If left unspecified the graph is retained + only if saving is done in a function transformation. Default: ``None``. 
)pbdoc"); m.def( "where", diff --git a/python/tests/test_load.py b/python/tests/test_load.py index c38fff661..2ee550604 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -64,7 +64,7 @@ class TestLoad(mlx_tests.MLXTestCase): load_arr_mlx_npy = np.load(save_file_mlx) self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy)) - def test_save_and_load_safetensor(self): + def test_save_and_load_safetensors(self): if not os.path.isdir(self.test_dir): os.mkdir(self.test_dir) @@ -82,7 +82,7 @@ class TestLoad(mlx_tests.MLXTestCase): } with open(save_file_mlx, "wb") as f: - mx.save_safetensor(f, save_dict) + mx.save_safetensors(f, save_dict) with open(save_file_mlx, "rb") as f: load_dict = mx.load(f) diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 873322f44..edff1aff6 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -14,13 +14,13 @@ std::string get_temp_file(const std::string& name) { return std::filesystem::temp_directory_path().append(name); } -TEST_CASE("test save_safetensor") { +TEST_CASE("test save_safetensors") { std::string file_path = get_temp_file("test_arr.safetensors"); auto map = std::unordered_map(); map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); map.insert({"test2", ones({2, 2})}); - save_safetensor(file_path, map); - auto safeDict = load_safetensor(file_path); + save_safetensors(file_path, map); + auto safeDict = load_safetensors(file_path); CHECK_EQ(safeDict.size(), 2); CHECK_EQ(safeDict.count("test"), 1); CHECK_EQ(safeDict.count("test2"), 1);