Metadata support for safetensors (#639)

* metadata support for safetensors

* aliases making it a little more readable

* addressing comments

* python binding tests
This commit is contained in:
Diogo
2024-02-08 22:33:15 -05:00
committed by GitHub
parent 221f8d3fc2
commit b57bd0488d
8 changed files with 108 additions and 69 deletions

View File

@@ -10,6 +10,14 @@
#include "mlx/stream.h"
namespace mlx::core {
using GGUFMetaData =
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
using GGUFLoad = std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, GGUFMetaData>>;
using SafetensorsLoad = std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string>>;
/** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a);
@@ -24,32 +32,29 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
array load(const std::string& file, StreamOrDevice s = {});
/** Load array map and string metadata from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors(
SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s = {});
std::unordered_map<std::string, array> load_safetensors(
SafetensorsLoad load_safetensors(
const std::string& file,
StreamOrDevice s = {});
void save_safetensors(
std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>);
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});
void save_safetensors(
const std::string& file,
std::unordered_map<std::string, array>);
using MetaData =
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});
/** Load array map and metadata from .gguf file format */
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
load_gguf(const std::string& file, StreamOrDevice s = {});
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
void save_gguf(
std::string file,
std::unordered_map<std::string, array> array_map,
std::unordered_map<std::string, MetaData> meta_data = {});
std::unordered_map<std::string, GGUFMetaData> meta_data = {});
} // namespace mlx::core

View File

@@ -82,7 +82,7 @@ void set_mx_value_from_gguf(
gguf_ctx* ctx,
uint32_t type,
gguf_value* val,
MetaData& value) {
GGUFMetaData& value) {
switch (type) {
case GGUF_VALUE_TYPE_UINT8:
value = array(val->uint8, uint8);
@@ -191,12 +191,12 @@ void set_mx_value_from_gguf(
}
}
std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) {
std::unordered_map<std::string, MetaData> metadata;
std::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ctx) {
std::unordered_map<std::string, GGUFMetaData> metadata;
gguf_key key;
while (gguf_get_key(ctx, &key)) {
std::string key_name = std::string(key.name, key.namelen);
auto& val = metadata.insert({key_name, MetaData{}}).first->second;
auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second;
set_mx_value_from_gguf(ctx, key.type, key.val, val);
}
return metadata;
@@ -230,10 +230,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
return array_map;
}
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
load_gguf(const std::string& file, StreamOrDevice s) {
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
gguf_ctx* ctx = gguf_open(file.c_str());
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");
@@ -280,7 +277,7 @@ void append_kv_array(
void save_gguf(
std::string file,
std::unordered_map<std::string, array> array_map,
std::unordered_map<std::string, MetaData> metadata /* = {} */) {
std::unordered_map<std::string, GGUFMetaData> metadata /* = {} */) {
// Add .gguf to file name if it is not there
if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
file += ".gguf";

View File

@@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) {
}
/** Load array from reader in safetensor format */
std::unordered_map<std::string, array> load_safetensors(
SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s) {
////////////////////////////////////////////////////////
@@ -121,9 +121,12 @@ std::unordered_map<std::string, array> load_safetensors(
size_t offset = jsonHeaderLength + 8;
// Load the arrays using metadata
std::unordered_map<std::string, array> res;
std::unordered_map<std::string, std::string> metadata_map;
for (const auto& item : metadata.items()) {
if (item.key() == "__metadata__") {
// ignore metadata for now
for (const auto& meta_item : item.value().items()) {
metadata_map.insert({meta_item.key(), meta_item.value()});
}
continue;
}
std::string dtype = item.value().at("dtype");
@@ -138,19 +141,18 @@ std::unordered_map<std::string, array> load_safetensors(
std::vector<array>{});
res.insert({item.key(), loaded_array});
}
return res;
return {res, metadata_map};
}
std::unordered_map<std::string, array> load_safetensors(
const std::string& file,
StreamOrDevice s) {
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
return load_safetensors(std::make_shared<io::FileReader>(file), s);
}
/** Save array map and string metadata to out stream in .safetensors format */
void save_safetensors(
std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a) {
std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
@@ -161,9 +163,11 @@ void save_safetensors(
////////////////////////////////////////////////////////
// Check array map
json parent;
parent["__metadata__"] = json::object({
{"format", "mlx"},
});
json _metadata;
for (auto& [key, value] : metadata) {
_metadata[key] = value;
}
parent["__metadata__"] = _metadata;
size_t offset = 0;
for (auto& [key, arr] : a) {
arr.eval();
@@ -204,7 +208,8 @@ void save_safetensors(
void save_safetensors(
const std::string& file_,
std::unordered_map<std::string, array> a) {
std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
// Open and check file
std::string file = file_;
@@ -214,7 +219,7 @@ void save_safetensors(
file += ".safetensors";
// Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a);
save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
}
} // namespace mlx::core