Metadata support for safetensors (#639)

* metadata support for safetensors

* aliases making it a little more readable

* addressing comments

* python binding tests
This commit is contained in:
Diogo 2024-02-08 22:33:15 -05:00 committed by GitHub
parent 221f8d3fc2
commit b57bd0488d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 108 additions and 69 deletions

View File

@ -10,6 +10,14 @@
#include "mlx/stream.h" #include "mlx/stream.h"
namespace mlx::core { namespace mlx::core {
using GGUFMetaData =
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
using GGUFLoad = std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, GGUFMetaData>>;
using SafetensorsLoad = std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string>>;
/** Save array to out stream in .npy format */ /** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a); void save(std::shared_ptr<io::Writer> out_stream, array a);
@ -24,32 +32,29 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
array load(const std::string& file, StreamOrDevice s = {}); array load(const std::string& file, StreamOrDevice s = {});
/** Load array map from .safetensors file format */ /** Load array map from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors( SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream, std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s = {}); StreamOrDevice s = {});
std::unordered_map<std::string, array> load_safetensors( SafetensorsLoad load_safetensors(
const std::string& file, const std::string& file,
StreamOrDevice s = {}); StreamOrDevice s = {});
void save_safetensors( void save_safetensors(
std::shared_ptr<io::Writer> in_stream, std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>); std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});
void save_safetensors( void save_safetensors(
const std::string& file, const std::string& file,
std::unordered_map<std::string, array>); std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});
using MetaData =
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
/** Load array map and metadata from .gguf file format */ /** Load array map and metadata from .gguf file format */
std::pair<
std::unordered_map<std::string, array>, GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
std::unordered_map<std::string, MetaData>>
load_gguf(const std::string& file, StreamOrDevice s = {});
void save_gguf( void save_gguf(
std::string file, std::string file,
std::unordered_map<std::string, array> array_map, std::unordered_map<std::string, array> array_map,
std::unordered_map<std::string, MetaData> meta_data = {}); std::unordered_map<std::string, GGUFMetaData> meta_data = {});
} // namespace mlx::core } // namespace mlx::core

View File

@ -82,7 +82,7 @@ void set_mx_value_from_gguf(
gguf_ctx* ctx, gguf_ctx* ctx,
uint32_t type, uint32_t type,
gguf_value* val, gguf_value* val,
MetaData& value) { GGUFMetaData& value) {
switch (type) { switch (type) {
case GGUF_VALUE_TYPE_UINT8: case GGUF_VALUE_TYPE_UINT8:
value = array(val->uint8, uint8); value = array(val->uint8, uint8);
@ -191,12 +191,12 @@ void set_mx_value_from_gguf(
} }
} }
std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) { std::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ctx) {
std::unordered_map<std::string, MetaData> metadata; std::unordered_map<std::string, GGUFMetaData> metadata;
gguf_key key; gguf_key key;
while (gguf_get_key(ctx, &key)) { while (gguf_get_key(ctx, &key)) {
std::string key_name = std::string(key.name, key.namelen); std::string key_name = std::string(key.name, key.namelen);
auto& val = metadata.insert({key_name, MetaData{}}).first->second; auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second;
set_mx_value_from_gguf(ctx, key.type, key.val, val); set_mx_value_from_gguf(ctx, key.type, key.val, val);
} }
return metadata; return metadata;
@ -230,10 +230,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
return array_map; return array_map;
} }
std::pair< GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
load_gguf(const std::string& file, StreamOrDevice s) {
gguf_ctx* ctx = gguf_open(file.c_str()); gguf_ctx* ctx = gguf_open(file.c_str());
if (!ctx) { if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed"); throw std::runtime_error("[load_gguf] gguf_init failed");
@ -280,7 +277,7 @@ void append_kv_array(
void save_gguf( void save_gguf(
std::string file, std::string file,
std::unordered_map<std::string, array> array_map, std::unordered_map<std::string, array> array_map,
std::unordered_map<std::string, MetaData> metadata /* = {} */) { std::unordered_map<std::string, GGUFMetaData> metadata /* = {} */) {
// Add .gguf to file name if it is not there // Add .gguf to file name if it is not there
if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") { if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
file += ".gguf"; file += ".gguf";

View File

@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) {
} }
/** Load array from reader in safetensor format */ /** Load array from reader in safetensor format */
std::unordered_map<std::string, array> load_safetensors( SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream, std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s) { StreamOrDevice s) {
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
@ -121,9 +121,12 @@ std::unordered_map<std::string, array> load_safetensors(
size_t offset = jsonHeaderLength + 8; size_t offset = jsonHeaderLength + 8;
// Load the arrays using metadata // Load the arrays using metadata
std::unordered_map<std::string, array> res; std::unordered_map<std::string, array> res;
std::unordered_map<std::string, std::string> metadata_map;
for (const auto& item : metadata.items()) { for (const auto& item : metadata.items()) {
if (item.key() == "__metadata__") { if (item.key() == "__metadata__") {
// ignore metadata for now for (const auto& meta_item : item.value().items()) {
metadata_map.insert({meta_item.key(), meta_item.value()});
}
continue; continue;
} }
std::string dtype = item.value().at("dtype"); std::string dtype = item.value().at("dtype");
@ -138,19 +141,18 @@ std::unordered_map<std::string, array> load_safetensors(
std::vector<array>{}); std::vector<array>{});
res.insert({item.key(), loaded_array}); res.insert({item.key(), loaded_array});
} }
return res; return {res, metadata_map};
} }
std::unordered_map<std::string, array> load_safetensors( SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
const std::string& file,
StreamOrDevice s) {
return load_safetensors(std::make_shared<io::FileReader>(file), s); return load_safetensors(std::make_shared<io::FileReader>(file), s);
} }
/** Save arrays to out stream in .safetensors format */ /** Save arrays to out stream in .safetensors format */
void save_safetensors( void save_safetensors(
std::shared_ptr<io::Writer> out_stream, std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a) { std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Check file // Check file
if (!out_stream->good() || !out_stream->is_open()) { if (!out_stream->good() || !out_stream->is_open()) {
@ -161,9 +163,11 @@ void save_safetensors(
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Check array map // Check array map
json parent; json parent;
parent["__metadata__"] = json::object({ json _metadata;
{"format", "mlx"}, for (auto& [key, value] : metadata) {
}); _metadata[key] = value;
}
parent["__metadata__"] = _metadata;
size_t offset = 0; size_t offset = 0;
for (auto& [key, arr] : a) { for (auto& [key, arr] : a) {
arr.eval(); arr.eval();
@ -204,7 +208,8 @@ void save_safetensors(
void save_safetensors( void save_safetensors(
const std::string& file_, const std::string& file_,
std::unordered_map<std::string, array> a) { std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
// Open and check file // Open and check file
std::string file = file_; std::string file = file_;
@ -214,7 +219,7 @@ void save_safetensors(
file += ".safetensors"; file += ".safetensors";
// Serialize array // Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a); save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -160,31 +160,29 @@ class PyFileReader : public io::Reader {
py::object tell_func_; py::object tell_func_;
}; };
std::unordered_map<std::string, array> mlx_load_safetensor_helper( std::pair<
py::object file, std::unordered_map<std::string, array>,
StreamOrDevice s) { std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
return load_safetensors(py::cast<std::string>(file), s); return load_safetensors(py::cast<std::string>(file), s);
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s); auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
{ {
py::gil_scoped_release gil; py::gil_scoped_release gil;
for (auto& [key, arr] : arr) { for (auto& [key, arr] : std::get<0>(res)) {
arr.eval(); arr.eval();
} }
} }
return arr; return res;
} }
throw std::invalid_argument( throw std::invalid_argument(
"[load_safetensors] Input must be a file-like object, or string"); "[load_safetensors] Input must be a file-like object, or string");
} }
std::pair< GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
return load_gguf(py::cast<std::string>(file), s); return load_gguf(py::cast<std::string>(file), s);
} }
@ -274,12 +272,16 @@ LoadOutputTypes mlx_load_helper(
format.emplace(fname.substr(ext + 1)); format.emplace(fname.substr(ext + 1));
} }
if (return_metadata && format.value() != "gguf") { if (return_metadata && (format.value() == "npy" || format.value() == "npz")) {
throw std::invalid_argument( throw std::invalid_argument(
"[load] metadata not supported for format " + format.value()); "[load] metadata not supported for format " + format.value());
} }
if (format.value() == "safetensors") { if (format.value() == "safetensors") {
return mlx_load_safetensor_helper(file, s); auto [dict, metadata] = mlx_load_safetensor_helper(file, s);
if (return_metadata) {
return std::make_pair(dict, metadata);
}
return dict;
} else if (format.value() == "npz") { } else if (format.value() == "npz") {
return mlx_load_npz_helper(file, s); return mlx_load_npz_helper(file, s);
} else if (format.value() == "npy") { } else if (format.value() == "npy") {
@ -444,18 +446,33 @@ void mlx_savez_helper(
return; return;
} }
void mlx_save_safetensor_helper(py::object file, py::dict d) { void mlx_save_safetensor_helper(
py::object file,
py::dict d,
std::optional<py::dict> m) {
std::unordered_map<std::string, std::string> metadata_map;
if (m) {
try {
metadata_map =
m.value().cast<std::unordered_map<std::string, std::string>>();
} catch (const py::cast_error& e) {
throw std::invalid_argument(
"[save_safetensors] Metadata must be a dictionary with string keys and values");
}
} else {
metadata_map = std::unordered_map<std::string, std::string>();
}
auto arrays_map = d.cast<std::unordered_map<std::string, array>>(); auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) { if (py::isinstance<py::str>(file)) {
{ {
py::gil_scoped_release nogil; py::gil_scoped_release nogil;
save_safetensors(py::cast<std::string>(file), arrays_map); save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map);
} }
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
{ {
py::gil_scoped_release nogil; py::gil_scoped_release nogil;
save_safetensors(writer, arrays_map); save_safetensors(writer, arrays_map, metadata_map);
} }
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -471,7 +488,7 @@ void mlx_save_gguf_helper(
if (py::isinstance<py::str>(file)) { if (py::isinstance<py::str>(file)) {
if (m) { if (m) {
auto metadata_map = auto metadata_map =
m.value().cast<std::unordered_map<std::string, MetaData>>(); m.value().cast<std::unordered_map<std::string, GGUFMetaData>>();
{ {
py::gil_scoped_release nogil; py::gil_scoped_release nogil;
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map); save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);

View File

@ -15,19 +15,17 @@ using namespace mlx::core;
using LoadOutputTypes = std::variant< using LoadOutputTypes = std::variant<
array, array,
std::unordered_map<std::string, array>, std::unordered_map<std::string, array>,
std::pair< SafetensorsLoad,
std::unordered_map<std::string, array>, GGUFLoad>;
std::unordered_map<std::string, MetaData>>>;
std::unordered_map<std::string, array> mlx_load_safetensor_helper( SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s);
void mlx_save_safetensor_helper(
py::object file, py::object file,
StreamOrDevice s); py::dict d,
void mlx_save_safetensor_helper(py::object file, py::dict d); std::optional<py::dict> m);
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s);
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
mlx_load_gguf_helper(py::object file, StreamOrDevice s);
void mlx_save_gguf_helper( void mlx_save_gguf_helper(
py::object file, py::object file,
py::dict d, py::dict d,

View File

@ -3214,8 +3214,9 @@ void init_ops(py::module_& m) {
&mlx_save_safetensor_helper, &mlx_save_safetensor_helper,
"file"_a, "file"_a,
"arrays"_a, "arrays"_a,
"metadata"_a = none,
R"pbdoc( R"pbdoc(
save_safetensors(file: str, arrays: Dict[str, array]) save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None)
Save array(s) to a binary file in ``.safetensors`` format. Save array(s) to a binary file in ``.safetensors`` format.
@ -3225,6 +3226,7 @@ void init_ops(py::module_& m) {
Args: Args:
file (file, str): File in which the array is saved. file (file, str): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to be saved. arrays (dict(str, array)): The dictionary of names to arrays to be saved.
metadata (dict(str, str), optional): The dictionary of metadata to be saved.
)pbdoc"); )pbdoc");
m.def( m.def(
"save_gguf", "save_gguf",

View File

@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
def test_save_and_load_safetensors(self): def test_save_and_load_safetensors(self):
if not os.path.isdir(self.test_dir): if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir) os.mkdir(self.test_dir)
with self.assertRaises(Exception):
mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
mx.save_safetensors(
"test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
)
res = mx.load("test.safetensors", return_metadata=True)
self.assertEqual(len(res), 2)
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
for dt in self.dtypes + ["bfloat16"]: for dt in self.dtypes + ["bfloat16"]:
with self.subTest(dtype=dt): with self.subTest(dtype=dt):

View File

@ -19,8 +19,14 @@ TEST_CASE("test save_safetensors") {
auto map = std::unordered_map<std::string, array>(); auto map = std::unordered_map<std::string, array>();
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
map.insert({"test2", ones({2, 2})}); map.insert({"test2", ones({2, 2})});
save_safetensors(file_path, map); auto _metadata = std::unordered_map<std::string, std::string>();
auto dict = load_safetensors(file_path); _metadata.insert({"test", "test"});
_metadata.insert({"test2", "test2"});
save_safetensors(file_path, map, _metadata);
auto [dict, metadata] = load_safetensors(file_path);
CHECK_EQ(metadata, _metadata);
CHECK_EQ(dict.size(), 2); CHECK_EQ(dict.size(), 2);
CHECK_EQ(dict.count("test"), 1); CHECK_EQ(dict.count("test"), 1);
CHECK_EQ(dict.count("test2"), 1); CHECK_EQ(dict.count("test2"), 1);
@ -55,7 +61,7 @@ TEST_CASE("test gguf") {
} }
// Test saving and loading string metadata // Test saving and loading string metadata
std::unordered_map<std::string, MetaData> original_metadata; std::unordered_map<std::string, GGUFMetaData> original_metadata;
original_metadata.insert({"test_str", "my string"}); original_metadata.insert({"test_str", "my string"});
save_gguf(file_path, original_weights, original_metadata); save_gguf(file_path, original_weights, original_metadata);
@ -97,7 +103,7 @@ TEST_CASE("test gguf metadata") {
// Scalar array // Scalar array
{ {
std::unordered_map<std::string, MetaData> original_metadata; std::unordered_map<std::string, GGUFMetaData> original_metadata;
original_metadata.insert({"test_arr", array(1.0)}); original_metadata.insert({"test_arr", array(1.0)});
save_gguf(file_path, original_weights, original_metadata); save_gguf(file_path, original_weights, original_metadata);
@ -111,7 +117,7 @@ TEST_CASE("test gguf metadata") {
// 1D Array // 1D Array
{ {
std::unordered_map<std::string, MetaData> original_metadata; std::unordered_map<std::string, GGUFMetaData> original_metadata;
auto arr = array({1.0, 2.0}); auto arr = array({1.0, 2.0});
original_metadata.insert({"test_arr", arr}); original_metadata.insert({"test_arr", arr});
save_gguf(file_path, original_weights, original_metadata); save_gguf(file_path, original_weights, original_metadata);
@ -138,21 +144,21 @@ TEST_CASE("test gguf metadata") {
// > 1D array throws // > 1D array throws
{ {
std::unordered_map<std::string, MetaData> original_metadata; std::unordered_map<std::string, GGUFMetaData> original_metadata;
original_metadata.insert({"test_arr", array({1.0}, {1, 1})}); original_metadata.insert({"test_arr", array({1.0}, {1, 1})});
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
} }
// empty array throws // empty array throws
{ {
std::unordered_map<std::string, MetaData> original_metadata; std::unordered_map<std::string, GGUFMetaData> original_metadata;
original_metadata.insert({"test_arr", array({})}); original_metadata.insert({"test_arr", array({})});
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
} }
// vector of string // vector of string
{ {
std::unordered_map<std::string, MetaData> original_metadata; std::unordered_map<std::string, GGUFMetaData> original_metadata;
std::vector<std::string> data = {"data1", "data2", "data1234"}; std::vector<std::string> data = {"data1", "data2", "data1234"};
original_metadata.insert({"meta", data}); original_metadata.insert({"meta", data});
save_gguf(file_path, original_weights, original_metadata); save_gguf(file_path, original_weights, original_metadata);
@ -169,7 +175,7 @@ TEST_CASE("test gguf metadata") {
// vector of string, string, scalar, and array // vector of string, string, scalar, and array
{ {
std::unordered_map<std::string, MetaData> original_metadata; std::unordered_map<std::string, GGUFMetaData> original_metadata;
std::vector<std::string> data = {"data1", "data2", "data1234"}; std::vector<std::string> data = {"data1", "data2", "data1234"};
original_metadata.insert({"meta1", data}); original_metadata.insert({"meta1", data});
original_metadata.insert({"meta2", array(2.5)}); original_metadata.insert({"meta2", array(2.5)});