mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Metadata support for safetensors (#639)
* metadata support for safetensors * aliases making it a little more readable * addressing comments * python binding tests
This commit is contained in:
		| @@ -82,7 +82,7 @@ void set_mx_value_from_gguf( | ||||
|     gguf_ctx* ctx, | ||||
|     uint32_t type, | ||||
|     gguf_value* val, | ||||
|     MetaData& value) { | ||||
|     GGUFMetaData& value) { | ||||
|   switch (type) { | ||||
|     case GGUF_VALUE_TYPE_UINT8: | ||||
|       value = array(val->uint8, uint8); | ||||
| @@ -191,12 +191,12 @@ void set_mx_value_from_gguf( | ||||
|   } | ||||
| } | ||||
|  | ||||
| std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) { | ||||
|   std::unordered_map<std::string, MetaData> metadata; | ||||
| std::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ctx) { | ||||
|   std::unordered_map<std::string, GGUFMetaData> metadata; | ||||
|   gguf_key key; | ||||
|   while (gguf_get_key(ctx, &key)) { | ||||
|     std::string key_name = std::string(key.name, key.namelen); | ||||
|     auto& val = metadata.insert({key_name, MetaData{}}).first->second; | ||||
|     auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second; | ||||
|     set_mx_value_from_gguf(ctx, key.type, key.val, val); | ||||
|   } | ||||
|   return metadata; | ||||
| @@ -230,10 +230,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) { | ||||
|   return array_map; | ||||
| } | ||||
|  | ||||
| std::pair< | ||||
|     std::unordered_map<std::string, array>, | ||||
|     std::unordered_map<std::string, MetaData>> | ||||
| load_gguf(const std::string& file, StreamOrDevice s) { | ||||
| GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { | ||||
|   gguf_ctx* ctx = gguf_open(file.c_str()); | ||||
|   if (!ctx) { | ||||
|     throw std::runtime_error("[load_gguf] gguf_init failed"); | ||||
| @@ -280,7 +277,7 @@ void append_kv_array( | ||||
| void save_gguf( | ||||
|     std::string file, | ||||
|     std::unordered_map<std::string, array> array_map, | ||||
|     std::unordered_map<std::string, MetaData> metadata /* = {} */) { | ||||
|     std::unordered_map<std::string, GGUFMetaData> metadata /* = {} */) { | ||||
|   // Add .gguf to file name if it is not there | ||||
|   if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") { | ||||
|     file += ".gguf"; | ||||
|   | ||||
| @@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) { | ||||
| } | ||||
|  | ||||
| /** Load array from reader in safetensor format */ | ||||
| std::unordered_map<std::string, array> load_safetensors( | ||||
| SafetensorsLoad load_safetensors( | ||||
|     std::shared_ptr<io::Reader> in_stream, | ||||
|     StreamOrDevice s) { | ||||
|   //////////////////////////////////////////////////////// | ||||
| @@ -121,9 +121,12 @@ std::unordered_map<std::string, array> load_safetensors( | ||||
|   size_t offset = jsonHeaderLength + 8; | ||||
|   // Load the arrays using metadata | ||||
|   std::unordered_map<std::string, array> res; | ||||
|   std::unordered_map<std::string, std::string> metadata_map; | ||||
|   for (const auto& item : metadata.items()) { | ||||
|     if (item.key() == "__metadata__") { | ||||
|       // ignore metadata for now | ||||
|       for (const auto& meta_item : item.value().items()) { | ||||
|         metadata_map.insert({meta_item.key(), meta_item.value()}); | ||||
|       } | ||||
|       continue; | ||||
|     } | ||||
|     std::string dtype = item.value().at("dtype"); | ||||
| @@ -138,19 +141,18 @@ std::unordered_map<std::string, array> load_safetensors( | ||||
|         std::vector<array>{}); | ||||
|     res.insert({item.key(), loaded_array}); | ||||
|   } | ||||
|   return res; | ||||
|   return {res, metadata_map}; | ||||
| } | ||||
|  | ||||
| std::unordered_map<std::string, array> load_safetensors( | ||||
|     const std::string& file, | ||||
|     StreamOrDevice s) { | ||||
| SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) { | ||||
|   return load_safetensors(std::make_shared<io::FileReader>(file), s); | ||||
| } | ||||
|  | ||||
| /** Save array to out stream in .safetensors format */ | ||||
| void save_safetensors( | ||||
|     std::shared_ptr<io::Writer> out_stream, | ||||
|     std::unordered_map<std::string, array> a) { | ||||
|     std::unordered_map<std::string, array> a, | ||||
|     std::unordered_map<std::string, std::string> metadata /* = {} */) { | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check file | ||||
|   if (!out_stream->good() || !out_stream->is_open()) { | ||||
| @@ -161,9 +163,11 @@ void save_safetensors( | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check array map | ||||
|   json parent; | ||||
|   parent["__metadata__"] = json::object({ | ||||
|       {"format", "mlx"}, | ||||
|   }); | ||||
|   json _metadata; | ||||
|   for (auto& [key, value] : metadata) { | ||||
|     _metadata[key] = value; | ||||
|   } | ||||
|   parent["__metadata__"] = _metadata; | ||||
|   size_t offset = 0; | ||||
|   for (auto& [key, arr] : a) { | ||||
|     arr.eval(); | ||||
| @@ -204,7 +208,8 @@ void save_safetensors( | ||||
|  | ||||
| void save_safetensors( | ||||
|     const std::string& file_, | ||||
|     std::unordered_map<std::string, array> a) { | ||||
|     std::unordered_map<std::string, array> a, | ||||
|     std::unordered_map<std::string, std::string> metadata /* = {} */) { | ||||
|   // Open and check file | ||||
|   std::string file = file_; | ||||
|  | ||||
| @@ -214,7 +219,7 @@ void save_safetensors( | ||||
|     file += ".safetensors"; | ||||
|  | ||||
|   // Serialize array | ||||
|   save_safetensors(std::make_shared<io::FileWriter>(file), a); | ||||
|   save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Diogo
					Diogo