mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Metadata support for safetensors (#639)
* metadata support for safetensors * aliases making it a little more readable * addressing comments * python binding tests
This commit is contained in:
		
							
								
								
									
										29
									
								
								mlx/io.h
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								mlx/io.h
									
									
									
									
									
								
							| @@ -10,6 +10,14 @@ | ||||
| #include "mlx/stream.h" | ||||
|  | ||||
| namespace mlx::core { | ||||
| using GGUFMetaData = | ||||
|     std::variant<std::monostate, array, std::string, std::vector<std::string>>; | ||||
| using GGUFLoad = std::pair< | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, GGUFMetaData>>; | ||||
| using SafetensorsLoad = std::pair< | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, std::string>>; | ||||
|  | ||||
| /** Save array to out stream in .npy format */ | ||||
| void save(std::shared_ptr<io::Writer> out_stream, array a); | ||||
| @@ -24,32 +32,29 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {}); | ||||
| array load(const std::string& file, StreamOrDevice s = {}); | ||||
|  | ||||
| /** Load array map from .safetensors file format */ | ||||
| std::unordered_map<std::string, array> load_safetensors( | ||||
| SafetensorsLoad load_safetensors( | ||||
|     std::shared_ptr<io::Reader> in_stream, | ||||
|     StreamOrDevice s = {}); | ||||
| std::unordered_map<std::string, array> load_safetensors( | ||||
| SafetensorsLoad load_safetensors( | ||||
|     const std::string& file, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| void save_safetensors( | ||||
|     std::shared_ptr<io::Writer> in_stream, | ||||
|     std::unordered_map<std::string, array>); | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, std::string> metadata = {}); | ||||
| void save_safetensors( | ||||
|     const std::string& file, | ||||
|     std::unordered_map<std::string, array>); | ||||
|  | ||||
| using MetaData = | ||||
|     std::variant<std::monostate, array, std::string, std::vector<std::string>>; | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, std::string> metadata = {}); | ||||
|  | ||||
| /** Load array map and metadata from .gguf file format */ | ||||
| std::pair< | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, MetaData>> | ||||
| load_gguf(const std::string& file, StreamOrDevice s = {}); | ||||
|  | ||||
| GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {}); | ||||
|  | ||||
| void save_gguf( | ||||
|     std::string file, | ||||
|     std::unordered_map<std::string, array> array_map, | ||||
|     std::unordered_map<std::string, MetaData> meta_data = {}); | ||||
|     std::unordered_map<std::string, GGUFMetaData> meta_data = {}); | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
| @@ -82,7 +82,7 @@ void set_mx_value_from_gguf( | ||||
|     gguf_ctx* ctx, | ||||
|     uint32_t type, | ||||
|     gguf_value* val, | ||||
|     MetaData& value) { | ||||
|     GGUFMetaData& value) { | ||||
|   switch (type) { | ||||
|     case GGUF_VALUE_TYPE_UINT8: | ||||
|       value = array(val->uint8, uint8); | ||||
| @@ -191,12 +191,12 @@ void set_mx_value_from_gguf( | ||||
|   } | ||||
| } | ||||
|  | ||||
| std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) { | ||||
|   std::unordered_map<std::string, MetaData> metadata; | ||||
| std::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ctx) { | ||||
|   std::unordered_map<std::string, GGUFMetaData> metadata; | ||||
|   gguf_key key; | ||||
|   while (gguf_get_key(ctx, &key)) { | ||||
|     std::string key_name = std::string(key.name, key.namelen); | ||||
|     auto& val = metadata.insert({key_name, MetaData{}}).first->second; | ||||
|     auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second; | ||||
|     set_mx_value_from_gguf(ctx, key.type, key.val, val); | ||||
|   } | ||||
|   return metadata; | ||||
| @@ -230,10 +230,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) { | ||||
|   return array_map; | ||||
| } | ||||
|  | ||||
| std::pair< | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, MetaData>> | ||||
| load_gguf(const std::string& file, StreamOrDevice s) { | ||||
| GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { | ||||
|   gguf_ctx* ctx = gguf_open(file.c_str()); | ||||
|   if (!ctx) { | ||||
|     throw std::runtime_error("[load_gguf] gguf_init failed"); | ||||
| @@ -280,7 +277,7 @@ void append_kv_array( | ||||
| void save_gguf( | ||||
|     std::string file, | ||||
|     std::unordered_map<std::string, array> array_map, | ||||
|     std::unordered_map<std::string, MetaData> metadata /* = {} */) { | ||||
|     std::unordered_map<std::string, GGUFMetaData> metadata /* = {} */) { | ||||
|   // Add .gguf to file name if it is not there | ||||
|   if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") { | ||||
|     file += ".gguf"; | ||||
|   | ||||
| @@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) { | ||||
| } | ||||
|  | ||||
| /** Load array from reader in safetensor format */ | ||||
| std::unordered_map<std::string, array> load_safetensors( | ||||
| SafetensorsLoad load_safetensors( | ||||
|     std::shared_ptr<io::Reader> in_stream, | ||||
|     StreamOrDevice s) { | ||||
|   //////////////////////////////////////////////////////// | ||||
| @@ -121,9 +121,12 @@ std::unordered_map<std::string, array> load_safetensors( | ||||
|   size_t offset = jsonHeaderLength + 8; | ||||
|   // Load the arrays using metadata | ||||
|   std::unordered_map<std::string, array> res; | ||||
|   std::unordered_map<std::string, std::string> metadata_map; | ||||
|   for (const auto& item : metadata.items()) { | ||||
|     if (item.key() == "__metadata__") { | ||||
|       // ignore metadata for now | ||||
|       for (const auto& meta_item : item.value().items()) { | ||||
|         metadata_map.insert({meta_item.key(), meta_item.value()}); | ||||
|       } | ||||
|       continue; | ||||
|     } | ||||
|     std::string dtype = item.value().at("dtype"); | ||||
| @@ -138,19 +141,18 @@ std::unordered_map<std::string, array> load_safetensors( | ||||
|         std::vector<array>{}); | ||||
|     res.insert({item.key(), loaded_array}); | ||||
|   } | ||||
|   return res; | ||||
|   return {res, metadata_map}; | ||||
| } | ||||
|  | ||||
| std::unordered_map<std::string, array> load_safetensors( | ||||
|     const std::string& file, | ||||
|     StreamOrDevice s) { | ||||
| SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) { | ||||
|   return load_safetensors(std::make_shared<io::FileReader>(file), s); | ||||
| } | ||||
|  | ||||
| /** Save array map to out stream in .safetensors format */ | ||||
| void save_safetensors( | ||||
|     std::shared_ptr<io::Writer> out_stream, | ||||
|     std::unordered_map<std::string, array> a) { | ||||
|     std::unordered_map<std::string, array> a, | ||||
|     std::unordered_map<std::string, std::string> metadata /* = {} */) { | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check file | ||||
|   if (!out_stream->good() || !out_stream->is_open()) { | ||||
| @@ -161,9 +163,11 @@ void save_safetensors( | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check array map | ||||
|   json parent; | ||||
|   parent["__metadata__"] = json::object({ | ||||
|       {"format", "mlx"}, | ||||
|   }); | ||||
|   json _metadata; | ||||
|   for (auto& [key, value] : metadata) { | ||||
|     _metadata[key] = value; | ||||
|   } | ||||
|   parent["__metadata__"] = _metadata; | ||||
|   size_t offset = 0; | ||||
|   for (auto& [key, arr] : a) { | ||||
|     arr.eval(); | ||||
| @@ -204,7 +208,8 @@ void save_safetensors( | ||||
|  | ||||
| void save_safetensors( | ||||
|     const std::string& file_, | ||||
|     std::unordered_map<std::string, array> a) { | ||||
|     std::unordered_map<std::string, array> a, | ||||
|     std::unordered_map<std::string, std::string> metadata /* = {} */) { | ||||
|   // Open and check file | ||||
|   std::string file = file_; | ||||
|  | ||||
| @@ -214,7 +219,7 @@ void save_safetensors( | ||||
|     file += ".safetensors"; | ||||
|  | ||||
|   // Serialize array | ||||
|   save_safetensors(std::make_shared<io::FileWriter>(file), a); | ||||
|   save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
| @@ -160,31 +160,29 @@ class PyFileReader : public io::Reader { | ||||
|   py::object tell_func_; | ||||
| }; | ||||
|  | ||||
| std::unordered_map<std::string, array> mlx_load_safetensor_helper( | ||||
|     py::object file, | ||||
|     StreamOrDevice s) { | ||||
| std::pair< | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, std::string>> | ||||
| mlx_load_safetensor_helper(py::object file, StreamOrDevice s) { | ||||
|   if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string | ||||
|     return load_safetensors(py::cast<std::string>(file), s); | ||||
|   } else if (is_istream_object(file)) { | ||||
|     // If we don't own the stream and it was passed to us, eval immediately | ||||
|     auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s); | ||||
|     auto res = load_safetensors(std::make_shared<PyFileReader>(file), s); | ||||
|     { | ||||
|       py::gil_scoped_release gil; | ||||
|       for (auto& [key, arr] : arr) { | ||||
|       for (auto& [key, arr] : std::get<0>(res)) { | ||||
|         arr.eval(); | ||||
|       } | ||||
|     } | ||||
|     return arr; | ||||
|     return res; | ||||
|   } | ||||
|  | ||||
|   throw std::invalid_argument( | ||||
|       "[load_safetensors] Input must be a file-like object, or string"); | ||||
| } | ||||
|  | ||||
| std::pair< | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, MetaData>> | ||||
| mlx_load_gguf_helper(py::object file, StreamOrDevice s) { | ||||
| GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) { | ||||
|   if (py::isinstance<py::str>(file)) { // Assume .gguf file path string | ||||
|     return load_gguf(py::cast<std::string>(file), s); | ||||
|   } | ||||
| @@ -274,12 +272,16 @@ LoadOutputTypes mlx_load_helper( | ||||
|     format.emplace(fname.substr(ext + 1)); | ||||
|   } | ||||
|  | ||||
|   if (return_metadata && format.value() != "gguf") { | ||||
|   if (return_metadata && (format.value() == "npy" || format.value() == "npz")) { | ||||
|     throw std::invalid_argument( | ||||
|         "[load] metadata not supported for format " + format.value()); | ||||
|   } | ||||
|   if (format.value() == "safetensors") { | ||||
|     return mlx_load_safetensor_helper(file, s); | ||||
|     auto [dict, metadata] = mlx_load_safetensor_helper(file, s); | ||||
|     if (return_metadata) { | ||||
|       return std::make_pair(dict, metadata); | ||||
|     } | ||||
|     return dict; | ||||
|   } else if (format.value() == "npz") { | ||||
|     return mlx_load_npz_helper(file, s); | ||||
|   } else if (format.value() == "npy") { | ||||
| @@ -444,18 +446,33 @@ void mlx_savez_helper( | ||||
|   return; | ||||
| } | ||||
|  | ||||
| void mlx_save_safetensor_helper(py::object file, py::dict d) { | ||||
| void mlx_save_safetensor_helper( | ||||
|     py::object file, | ||||
|     py::dict d, | ||||
|     std::optional<py::dict> m) { | ||||
|   std::unordered_map<std::string, std::string> metadata_map; | ||||
|   if (m) { | ||||
|     try { | ||||
|       metadata_map = | ||||
|           m.value().cast<std::unordered_map<std::string, std::string>>(); | ||||
|     } catch (const py::cast_error& e) { | ||||
|       throw std::invalid_argument( | ||||
|           "[save_safetensors] Metadata must be a dictionary with string keys and values"); | ||||
|     } | ||||
|   } else { | ||||
|     metadata_map = std::unordered_map<std::string, std::string>(); | ||||
|   } | ||||
|   auto arrays_map = d.cast<std::unordered_map<std::string, array>>(); | ||||
|   if (py::isinstance<py::str>(file)) { | ||||
|     { | ||||
|       py::gil_scoped_release nogil; | ||||
|       save_safetensors(py::cast<std::string>(file), arrays_map); | ||||
|       save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map); | ||||
|     } | ||||
|   } else if (is_ostream_object(file)) { | ||||
|     auto writer = std::make_shared<PyFileWriter>(file); | ||||
|     { | ||||
|       py::gil_scoped_release nogil; | ||||
|       save_safetensors(writer, arrays_map); | ||||
|       save_safetensors(writer, arrays_map, metadata_map); | ||||
|     } | ||||
|   } else { | ||||
|     throw std::invalid_argument( | ||||
| @@ -471,7 +488,7 @@ void mlx_save_gguf_helper( | ||||
|   if (py::isinstance<py::str>(file)) { | ||||
|     if (m) { | ||||
|       auto metadata_map = | ||||
|           m.value().cast<std::unordered_map<std::string, MetaData>>(); | ||||
|           m.value().cast<std::unordered_map<std::string, GGUFMetaData>>(); | ||||
|       { | ||||
|         py::gil_scoped_release nogil; | ||||
|         save_gguf(py::cast<std::string>(file), arrays_map, metadata_map); | ||||
|   | ||||
| @@ -15,19 +15,17 @@ using namespace mlx::core; | ||||
| using LoadOutputTypes = std::variant< | ||||
|     array, | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::pair< | ||||
|         std::unordered_map<std::string, array>, | ||||
|         std::unordered_map<std::string, MetaData>>>; | ||||
|     SafetensorsLoad, | ||||
|     GGUFLoad>; | ||||
|  | ||||
| std::unordered_map<std::string, array> mlx_load_safetensor_helper( | ||||
| SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s); | ||||
| void mlx_save_safetensor_helper( | ||||
|     py::object file, | ||||
|     StreamOrDevice s); | ||||
| void mlx_save_safetensor_helper(py::object file, py::dict d); | ||||
|     py::dict d, | ||||
|     std::optional<py::dict> m); | ||||
|  | ||||
| GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s); | ||||
|  | ||||
| std::pair< | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, MetaData>> | ||||
| mlx_load_gguf_helper(py::object file, StreamOrDevice s); | ||||
| void mlx_save_gguf_helper( | ||||
|     py::object file, | ||||
|     py::dict d, | ||||
|   | ||||
| @@ -3214,8 +3214,9 @@ void init_ops(py::module_& m) { | ||||
|       &mlx_save_safetensor_helper, | ||||
|       "file"_a, | ||||
|       "arrays"_a, | ||||
|       "metadata"_a = none, | ||||
|       R"pbdoc( | ||||
|         save_safetensors(file: str, arrays: Dict[str, array]) | ||||
|         save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None) | ||||
|  | ||||
|         Save array(s) to a binary file in ``.safetensors`` format. | ||||
|  | ||||
| @@ -3225,6 +3226,7 @@ void init_ops(py::module_& m) { | ||||
|         Args: | ||||
|             file (file, str): File in which the array is saved. | ||||
|             arrays (dict(str, array)): The dictionary of names to arrays to be saved. | ||||
|             metadata (dict(str, str), optional): The dictionary of metadata to be saved. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "save_gguf", | ||||
|   | ||||
| @@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|     def test_save_and_load_safetensors(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|         with self.assertRaises(Exception): | ||||
|             mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0}) | ||||
|  | ||||
|         mx.save_safetensors( | ||||
|             "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"} | ||||
|         ) | ||||
|         res = mx.load("test.safetensors", return_metadata=True) | ||||
|         self.assertEqual(len(res), 2) | ||||
|         self.assertEqual(res[1], {"testing": "test", "format": "mlx"}) | ||||
|  | ||||
|         for dt in self.dtypes + ["bfloat16"]: | ||||
|             with self.subTest(dtype=dt): | ||||
|   | ||||
| @@ -19,8 +19,14 @@ TEST_CASE("test save_safetensors") { | ||||
|   auto map = std::unordered_map<std::string, array>(); | ||||
|   map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); | ||||
|   map.insert({"test2", ones({2, 2})}); | ||||
|   save_safetensors(file_path, map); | ||||
|   auto dict = load_safetensors(file_path); | ||||
|   auto _metadata = std::unordered_map<std::string, std::string>(); | ||||
|   _metadata.insert({"test", "test"}); | ||||
|   _metadata.insert({"test2", "test2"}); | ||||
|   save_safetensors(file_path, map, _metadata); | ||||
|   auto [dict, metadata] = load_safetensors(file_path); | ||||
|  | ||||
|   CHECK_EQ(metadata, _metadata); | ||||
|  | ||||
|   CHECK_EQ(dict.size(), 2); | ||||
|   CHECK_EQ(dict.count("test"), 1); | ||||
|   CHECK_EQ(dict.count("test2"), 1); | ||||
| @@ -55,7 +61,7 @@ TEST_CASE("test gguf") { | ||||
|   } | ||||
|  | ||||
|   // Test saving and loading string metadata | ||||
|   std::unordered_map<std::string, MetaData> original_metadata; | ||||
|   std::unordered_map<std::string, GGUFMetaData> original_metadata; | ||||
|   original_metadata.insert({"test_str", "my string"}); | ||||
|  | ||||
|   save_gguf(file_path, original_weights, original_metadata); | ||||
| @@ -97,7 +103,7 @@ TEST_CASE("test gguf metadata") { | ||||
|  | ||||
|   // Scalar array | ||||
|   { | ||||
|     std::unordered_map<std::string, MetaData> original_metadata; | ||||
|     std::unordered_map<std::string, GGUFMetaData> original_metadata; | ||||
|     original_metadata.insert({"test_arr", array(1.0)}); | ||||
|     save_gguf(file_path, original_weights, original_metadata); | ||||
|  | ||||
| @@ -111,7 +117,7 @@ TEST_CASE("test gguf metadata") { | ||||
|  | ||||
|   // 1D Array | ||||
|   { | ||||
|     std::unordered_map<std::string, MetaData> original_metadata; | ||||
|     std::unordered_map<std::string, GGUFMetaData> original_metadata; | ||||
|     auto arr = array({1.0, 2.0}); | ||||
|     original_metadata.insert({"test_arr", arr}); | ||||
|     save_gguf(file_path, original_weights, original_metadata); | ||||
| @@ -138,21 +144,21 @@ TEST_CASE("test gguf metadata") { | ||||
|  | ||||
|   // > 1D array throws | ||||
|   { | ||||
|     std::unordered_map<std::string, MetaData> original_metadata; | ||||
|     std::unordered_map<std::string, GGUFMetaData> original_metadata; | ||||
|     original_metadata.insert({"test_arr", array({1.0}, {1, 1})}); | ||||
|     CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); | ||||
|   } | ||||
|  | ||||
|   // empty array throws | ||||
|   { | ||||
|     std::unordered_map<std::string, MetaData> original_metadata; | ||||
|     std::unordered_map<std::string, GGUFMetaData> original_metadata; | ||||
|     original_metadata.insert({"test_arr", array({})}); | ||||
|     CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); | ||||
|   } | ||||
|  | ||||
|   // vector of string | ||||
|   { | ||||
|     std::unordered_map<std::string, MetaData> original_metadata; | ||||
|     std::unordered_map<std::string, GGUFMetaData> original_metadata; | ||||
|     std::vector<std::string> data = {"data1", "data2", "data1234"}; | ||||
|     original_metadata.insert({"meta", data}); | ||||
|     save_gguf(file_path, original_weights, original_metadata); | ||||
| @@ -169,7 +175,7 @@ TEST_CASE("test gguf metadata") { | ||||
|  | ||||
|   // vector of string, string, scalar, and array | ||||
|   { | ||||
|     std::unordered_map<std::string, MetaData> original_metadata; | ||||
|     std::unordered_map<std::string, GGUFMetaData> original_metadata; | ||||
|     std::vector<std::string> data = {"data1", "data2", "data1234"}; | ||||
|     original_metadata.insert({"meta1", data}); | ||||
|     original_metadata.insert({"meta2", array(2.5)}); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Diogo
					Diogo