From b57bd0488daf195b1490db0d5ebeb833b7940e29 Mon Sep 17 00:00:00 2001 From: Diogo Date: Thu, 8 Feb 2024 22:33:15 -0500 Subject: [PATCH] Metadata support for safetensors (#639) * metadata support for safetensors * aliases making it alittle more readable * addressing comments * python binding tests --- mlx/io.h | 29 +++++++++++++---------- mlx/io/gguf.cpp | 15 +++++------- mlx/io/safetensor.cpp | 29 +++++++++++++---------- python/src/load.cpp | 49 ++++++++++++++++++++++++++------------- python/src/load.h | 18 +++++++------- python/src/ops.cpp | 4 +++- python/tests/test_load.py | 9 +++++++ tests/load_tests.cpp | 24 ++++++++++++------- 8 files changed, 108 insertions(+), 69 deletions(-) diff --git a/mlx/io.h b/mlx/io.h index c58e1959e..59866ea27 100644 --- a/mlx/io.h +++ b/mlx/io.h @@ -10,6 +10,14 @@ #include "mlx/stream.h" namespace mlx::core { +using GGUFMetaData = + std::variant>; +using GGUFLoad = std::pair< + std::unordered_map, + std::unordered_map>; +using SafetensorsLoad = std::pair< + std::unordered_map, + std::unordered_map>; /** Save array to out stream in .npy format */ void save(std::shared_ptr out_stream, array a); @@ -24,32 +32,29 @@ array load(std::shared_ptr in_stream, StreamOrDevice s = {}); array load(const std::string& file, StreamOrDevice s = {}); /** Load array map from .safetensors file format */ -std::unordered_map load_safetensors( +SafetensorsLoad load_safetensors( std::shared_ptr in_stream, StreamOrDevice s = {}); -std::unordered_map load_safetensors( +SafetensorsLoad load_safetensors( const std::string& file, StreamOrDevice s = {}); void save_safetensors( std::shared_ptr in_stream, - std::unordered_map); + std::unordered_map, + std::unordered_map metadata = {}); void save_safetensors( const std::string& file, - std::unordered_map); - -using MetaData = - std::variant>; + std::unordered_map, + std::unordered_map metadata = {}); /** Load array map and metadata from .gguf file format */ -std::pair< - std::unordered_map, - std::unordered_map> -load_gguf(const std::string& file, StreamOrDevice s = {}); + +GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {}); void save_gguf( std::string file, std::unordered_map array_map, - std::unordered_map meta_data = {}); + std::unordered_map meta_data = {}); } // namespace mlx::core diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index f4047d1a0..9e7953d6e 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -82,7 +82,7 @@ void set_mx_value_from_gguf( gguf_ctx* ctx, uint32_t type, gguf_value* val, - MetaData& value) { + GGUFMetaData& value) { switch (type) { case GGUF_VALUE_TYPE_UINT8: value = array(val->uint8, uint8); @@ -191,12 +191,12 @@ void set_mx_value_from_gguf( } } -std::unordered_map load_metadata(gguf_ctx* ctx) { - std::unordered_map metadata; +std::unordered_map load_metadata(gguf_ctx* ctx) { + std::unordered_map metadata; gguf_key key; while (gguf_get_key(ctx, &key)) { std::string key_name = std::string(key.name, key.namelen); - auto& val = metadata.insert({key_name, MetaData{}}).first->second; + auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second; set_mx_value_from_gguf(ctx, key.type, key.val, val); } return metadata; @@ -230,10 +230,7 @@ std::unordered_map load_arrays(gguf_ctx* ctx) { return array_map; } -std::pair< - std::unordered_map, - std::unordered_map> -load_gguf(const std::string& file, StreamOrDevice s) { +GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { gguf_ctx* ctx = gguf_open(file.c_str()); if (!ctx) { throw std::runtime_error("[load_gguf] gguf_init failed"); @@ -280,7 +277,7 @@ void append_kv_array( void save_gguf( std::string file, std::unordered_map array_map, - std::unordered_map metadata /* = {} */) { + std::unordered_map metadata /* = {} */) { // Add .gguf to file name if it is not there if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") { file += ".gguf"; diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index 7e7868d49..1dd59f444 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) { } /** Load array from reader in safetensor format */ -std::unordered_map load_safetensors( +SafetensorsLoad load_safetensors( std::shared_ptr in_stream, StreamOrDevice s) { //////////////////////////////////////////////////////// @@ -121,9 +121,12 @@ std::unordered_map load_safetensors( size_t offset = jsonHeaderLength + 8; // Load the arrays using metadata std::unordered_map res; + std::unordered_map metadata_map; for (const auto& item : metadata.items()) { if (item.key() == "__metadata__") { - // ignore metadata for now + for (const auto& meta_item : item.value().items()) { + metadata_map.insert({meta_item.key(), meta_item.value()}); + } continue; } std::string dtype = item.value().at("dtype"); @@ -138,19 +141,18 @@ std::unordered_map load_safetensors( std::vector{}); res.insert({item.key(), loaded_array}); } - return res; + return {res, metadata_map}; } -std::unordered_map load_safetensors( - const std::string& file, - StreamOrDevice s) { +SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) { return load_safetensors(std::make_shared(file), s); } /** Save array to out stream in .npy format */ void save_safetensors( std::shared_ptr out_stream, - std::unordered_map a) { + std::unordered_map a, + std::unordered_map metadata /* = {} */) { //////////////////////////////////////////////////////// // Check file if (!out_stream->good() || !out_stream->is_open()) { @@ -161,9 +163,11 @@ void save_safetensors( //////////////////////////////////////////////////////// // Check array map json parent; - parent["__metadata__"] = json::object({ - {"format", "mlx"}, - }); + json _metadata; + for (auto& [key, value] : metadata) { + _metadata[key] = value; + } + parent["__metadata__"] = _metadata; size_t offset = 0; for (auto& [key, arr] : a) { arr.eval(); @@ -204,7 +208,8 @@ void save_safetensors( void save_safetensors( const std::string& file_, - std::unordered_map a) { + std::unordered_map a, + std::unordered_map metadata /* = {} */) { // Open and check file std::string file = file_; @@ -214,7 +219,7 @@ void save_safetensors( file += ".safetensors"; // Serialize array - save_safetensors(std::make_shared(file), a); + save_safetensors(std::make_shared(file), a, metadata); } } // namespace mlx::core diff --git a/python/src/load.cpp b/python/src/load.cpp index 18e89c7fb..9b6a6861e 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -160,31 +160,29 @@ class PyFileReader : public io::Reader { py::object tell_func_; }; -std::unordered_map mlx_load_safetensor_helper( - py::object file, - StreamOrDevice s) { +std::pair< + std::unordered_map, + std::unordered_map> +mlx_load_safetensor_helper(py::object file, StreamOrDevice s) { if (py::isinstance(file)) { // Assume .safetensors file path string return load_safetensors(py::cast(file), s); } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately - auto arr = load_safetensors(std::make_shared(file), s); + auto res = load_safetensors(std::make_shared(file), s); { py::gil_scoped_release gil; - for (auto& [key, arr] : arr) { + for (auto& [key, arr] : std::get<0>(res)) { arr.eval(); } } - return arr; + return res; } throw std::invalid_argument( "[load_safetensors] Input must be a file-like object, or string"); } -std::pair< - std::unordered_map, - std::unordered_map> -mlx_load_gguf_helper(py::object file, StreamOrDevice s) { +GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) { if (py::isinstance(file)) { // Assume .gguf file path string return load_gguf(py::cast(file), s); } @@ -274,12 +272,16 @@ LoadOutputTypes mlx_load_helper( format.emplace(fname.substr(ext + 1)); } - if (return_metadata && format.value() != "gguf") { + if (return_metadata && (format.value() == "npy" || format.value() == "npz")) { throw std::invalid_argument( "[load] metadata not supported for format " + format.value()); } if (format.value() == "safetensors") { - return mlx_load_safetensor_helper(file, s); + auto [dict, metadata] = mlx_load_safetensor_helper(file, s); + if (return_metadata) { + return std::make_pair(dict, metadata); + } + return dict; } else if (format.value() == "npz") { return mlx_load_npz_helper(file, s); } else if (format.value() == "npy") { @@ -444,18 +446,33 @@ void mlx_savez_helper( return; } -void mlx_save_safetensor_helper(py::object file, py::dict d) { +void mlx_save_safetensor_helper( + py::object file, + py::dict d, + std::optional m) { + std::unordered_map metadata_map; + if (m) { + try { + metadata_map = + m.value().cast>(); + } catch (const py::cast_error& e) { + throw std::invalid_argument( + "[save_safetensors] Metadata must be a dictionary with string keys and values"); + } + } else { + metadata_map = std::unordered_map(); + } auto arrays_map = d.cast>(); if (py::isinstance(file)) { { py::gil_scoped_release nogil; - save_safetensors(py::cast(file), arrays_map); + save_safetensors(py::cast(file), arrays_map, metadata_map); } } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { py::gil_scoped_release nogil; - save_safetensors(writer, arrays_map); + save_safetensors(writer, arrays_map, metadata_map); } } else { throw std::invalid_argument( @@ -471,7 +488,7 @@ void mlx_save_gguf_helper( if (py::isinstance(file)) { if (m) { auto metadata_map = - m.value().cast>(); + m.value().cast>(); { py::gil_scoped_release nogil; save_gguf(py::cast(file), arrays_map, metadata_map); diff --git a/python/src/load.h b/python/src/load.h index dbe0f9cd6..21f0cff32 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -15,19 +15,17 @@ using namespace mlx::core; using LoadOutputTypes = std::variant< array, std::unordered_map, - std::pair< - std::unordered_map, - std::unordered_map>>; + SafetensorsLoad, + GGUFLoad>; -std::unordered_map mlx_load_safetensor_helper( +SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s); +void mlx_save_safetensor_helper( py::object file, - StreamOrDevice s); -void mlx_save_safetensor_helper(py::object file, py::dict d); + py::dict d, + std::optional m); + +GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s); -std::pair< - std::unordered_map, - std::unordered_map> -mlx_load_gguf_helper(py::object file, StreamOrDevice s); void mlx_save_gguf_helper( py::object file, py::dict d, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 02a401543..8e08e6ca9 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3214,8 +3214,9 @@ void init_ops(py::module_& m) { &mlx_save_safetensor_helper, "file"_a, "arrays"_a, + "metadata"_a = none, R"pbdoc( - save_safetensors(file: str, arrays: Dict[str, array]) + save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None) Save array(s) to a binary file in ``.safetensors`` format. @@ -3225,6 +3226,7 @@ void init_ops(py::module_& m) { Args: file (file, str): File in which the array is saved. arrays (dict(str, array)): The dictionary of names to arrays to be saved. + metadata (dict(str, str), optional): The dictionary of metadata to be saved. )pbdoc"); m.def( "save_gguf", diff --git a/python/tests/test_load.py b/python/tests/test_load.py index a37ba83a9..ab2645bcf 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase): def test_save_and_load_safetensors(self): if not os.path.isdir(self.test_dir): os.mkdir(self.test_dir) + with self.assertRaises(Exception): + mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0}) + + mx.save_safetensors( + "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"} + ) + res = mx.load("test.safetensors", return_metadata=True) + self.assertEqual(len(res), 2) + self.assertEqual(res[1], {"testing": "test", "format": "mlx"}) for dt in self.dtypes + ["bfloat16"]: with self.subTest(dtype=dt): diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 51d1659f3..3a7556b57 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -19,8 +19,14 @@ TEST_CASE("test save_safetensors") { auto map = std::unordered_map(); map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); map.insert({"test2", ones({2, 2})}); - save_safetensors(file_path, map); - auto dict = load_safetensors(file_path); + auto _metadata = std::unordered_map(); + _metadata.insert({"test", "test"}); + _metadata.insert({"test2", "test2"}); + save_safetensors(file_path, map, _metadata); + auto [dict, metadata] = load_safetensors(file_path); + + CHECK_EQ(metadata, _metadata); + CHECK_EQ(dict.size(), 2); CHECK_EQ(dict.count("test"), 1); CHECK_EQ(dict.count("test2"), 1); @@ -55,7 +61,7 @@ TEST_CASE("test gguf") { } // Test saving and loading string metadata - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_str", "my string"}); save_gguf(file_path, original_weights, original_metadata); @@ -97,7 +103,7 @@ TEST_CASE("test gguf metadata") { // Scalar array { - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_arr", array(1.0)}); save_gguf(file_path, original_weights, original_metadata); @@ -111,7 +117,7 @@ TEST_CASE("test gguf metadata") { // 1D Array { - std::unordered_map original_metadata; + std::unordered_map original_metadata; auto arr = array({1.0, 2.0}); original_metadata.insert({"test_arr", arr}); save_gguf(file_path, original_weights, original_metadata); @@ -138,21 +144,21 @@ TEST_CASE("test gguf metadata") { // > 1D array throws { - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_arr", array({1.0}, {1, 1})}); CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); } // empty array throws { - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_arr", array({})}); CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); } // vector of string { - std::unordered_map original_metadata; + std::unordered_map original_metadata; std::vector data = {"data1", "data2", "data1234"}; original_metadata.insert({"meta", data}); save_gguf(file_path, original_weights, original_metadata); @@ -169,7 +175,7 @@ TEST_CASE("test gguf metadata") { // vector of string, string, scalar, and array { - std::unordered_map original_metadata; + std::unordered_map original_metadata; std::vector data = {"data1", "data2", "data1234"}; original_metadata.insert({"meta1", data}); original_metadata.insert({"meta2", array(2.5)});