mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Metadata support for safetensors (#639)
* metadata support for safetensors * aliases making it a little more readable * addressing comments * python binding tests
This commit is contained in:
29
mlx/io.h
29
mlx/io.h
@@ -10,6 +10,14 @@
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core {
|
||||
using GGUFMetaData =
|
||||
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
|
||||
using GGUFLoad = std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, GGUFMetaData>>;
|
||||
using SafetensorsLoad = std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, std::string>>;
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
||||
@@ -24,32 +32,29 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||
array load(const std::string& file, StreamOrDevice s = {});
|
||||
|
||||
/** Load array map from .safetensors file format */
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
SafetensorsLoad load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s = {});
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
SafetensorsLoad load_safetensors(
|
||||
const std::string& file,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> in_stream,
|
||||
std::unordered_map<std::string, array>);
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, std::string> metadata = {});
|
||||
void save_safetensors(
|
||||
const std::string& file,
|
||||
std::unordered_map<std::string, array>);
|
||||
|
||||
using MetaData =
|
||||
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, std::string> metadata = {});
|
||||
|
||||
/** Load array map and metadata from .gguf file format */
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
load_gguf(const std::string& file, StreamOrDevice s = {});
|
||||
|
||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
|
||||
|
||||
void save_gguf(
|
||||
std::string file,
|
||||
std::unordered_map<std::string, array> array_map,
|
||||
std::unordered_map<std::string, MetaData> meta_data = {});
|
||||
std::unordered_map<std::string, GGUFMetaData> meta_data = {});
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -82,7 +82,7 @@ void set_mx_value_from_gguf(
|
||||
gguf_ctx* ctx,
|
||||
uint32_t type,
|
||||
gguf_value* val,
|
||||
MetaData& value) {
|
||||
GGUFMetaData& value) {
|
||||
switch (type) {
|
||||
case GGUF_VALUE_TYPE_UINT8:
|
||||
value = array(val->uint8, uint8);
|
||||
@@ -191,12 +191,12 @@ void set_mx_value_from_gguf(
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) {
|
||||
std::unordered_map<std::string, MetaData> metadata;
|
||||
std::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ctx) {
|
||||
std::unordered_map<std::string, GGUFMetaData> metadata;
|
||||
gguf_key key;
|
||||
while (gguf_get_key(ctx, &key)) {
|
||||
std::string key_name = std::string(key.name, key.namelen);
|
||||
auto& val = metadata.insert({key_name, MetaData{}}).first->second;
|
||||
auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second;
|
||||
set_mx_value_from_gguf(ctx, key.type, key.val, val);
|
||||
}
|
||||
return metadata;
|
||||
@@ -230,10 +230,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
||||
return array_map;
|
||||
}
|
||||
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
load_gguf(const std::string& file, StreamOrDevice s) {
|
||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
|
||||
gguf_ctx* ctx = gguf_open(file.c_str());
|
||||
if (!ctx) {
|
||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||
@@ -280,7 +277,7 @@ void append_kv_array(
|
||||
void save_gguf(
|
||||
std::string file,
|
||||
std::unordered_map<std::string, array> array_map,
|
||||
std::unordered_map<std::string, MetaData> metadata /* = {} */) {
|
||||
std::unordered_map<std::string, GGUFMetaData> metadata /* = {} */) {
|
||||
// Add .gguf to file name if it is not there
|
||||
if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
|
||||
file += ".gguf";
|
||||
|
@@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) {
|
||||
}
|
||||
|
||||
/** Load array from reader in safetensor format */
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
SafetensorsLoad load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s) {
|
||||
////////////////////////////////////////////////////////
|
||||
@@ -121,9 +121,12 @@ std::unordered_map<std::string, array> load_safetensors(
|
||||
size_t offset = jsonHeaderLength + 8;
|
||||
// Load the arrays using metadata
|
||||
std::unordered_map<std::string, array> res;
|
||||
std::unordered_map<std::string, std::string> metadata_map;
|
||||
for (const auto& item : metadata.items()) {
|
||||
if (item.key() == "__metadata__") {
|
||||
// ignore metadata for now
|
||||
for (const auto& meta_item : item.value().items()) {
|
||||
metadata_map.insert({meta_item.key(), meta_item.value()});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
std::string dtype = item.value().at("dtype");
|
||||
@@ -138,19 +141,18 @@ std::unordered_map<std::string, array> load_safetensors(
|
||||
std::vector<array>{});
|
||||
res.insert({item.key(), loaded_array});
|
||||
}
|
||||
return res;
|
||||
return {res, metadata_map};
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
const std::string& file,
|
||||
StreamOrDevice s) {
|
||||
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
|
||||
return load_safetensors(std::make_shared<io::FileReader>(file), s);
|
||||
}
|
||||
|
||||
/** Save arrays to out stream in .safetensors format */
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> out_stream,
|
||||
std::unordered_map<std::string, array> a) {
|
||||
std::unordered_map<std::string, array> a,
|
||||
std::unordered_map<std::string, std::string> metadata /* = {} */) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check file
|
||||
if (!out_stream->good() || !out_stream->is_open()) {
|
||||
@@ -161,9 +163,11 @@ void save_safetensors(
|
||||
////////////////////////////////////////////////////////
|
||||
// Check array map
|
||||
json parent;
|
||||
parent["__metadata__"] = json::object({
|
||||
{"format", "mlx"},
|
||||
});
|
||||
json _metadata;
|
||||
for (auto& [key, value] : metadata) {
|
||||
_metadata[key] = value;
|
||||
}
|
||||
parent["__metadata__"] = _metadata;
|
||||
size_t offset = 0;
|
||||
for (auto& [key, arr] : a) {
|
||||
arr.eval();
|
||||
@@ -204,7 +208,8 @@ void save_safetensors(
|
||||
|
||||
void save_safetensors(
|
||||
const std::string& file_,
|
||||
std::unordered_map<std::string, array> a) {
|
||||
std::unordered_map<std::string, array> a,
|
||||
std::unordered_map<std::string, std::string> metadata /* = {} */) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
@@ -214,7 +219,7 @@ void save_safetensors(
|
||||
file += ".safetensors";
|
||||
|
||||
// Serialize array
|
||||
save_safetensors(std::make_shared<io::FileWriter>(file), a);
|
||||
save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user