mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
Metadata support for safetensors (#639)
* metadata support for safetensors * aliases making it alittle more readable * addressing comments * python binding tests
This commit is contained in:
parent
221f8d3fc2
commit
b57bd0488d
29
mlx/io.h
29
mlx/io.h
@ -10,6 +10,14 @@
|
|||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
using GGUFMetaData =
|
||||||
|
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
|
||||||
|
using GGUFLoad = std::pair<
|
||||||
|
std::unordered_map<std::string, array>,
|
||||||
|
std::unordered_map<std::string, GGUFMetaData>>;
|
||||||
|
using SafetensorsLoad = std::pair<
|
||||||
|
std::unordered_map<std::string, array>,
|
||||||
|
std::unordered_map<std::string, std::string>>;
|
||||||
|
|
||||||
/** Save array to out stream in .npy format */
|
/** Save array to out stream in .npy format */
|
||||||
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
||||||
@ -24,32 +32,29 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
|||||||
array load(const std::string& file, StreamOrDevice s = {});
|
array load(const std::string& file, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Load array map from .safetensors file format */
|
/** Load array map from .safetensors file format */
|
||||||
std::unordered_map<std::string, array> load_safetensors(
|
SafetensorsLoad load_safetensors(
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
std::unordered_map<std::string, array> load_safetensors(
|
SafetensorsLoad load_safetensors(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
std::shared_ptr<io::Writer> in_stream,
|
std::shared_ptr<io::Writer> in_stream,
|
||||||
std::unordered_map<std::string, array>);
|
std::unordered_map<std::string, array>,
|
||||||
|
std::unordered_map<std::string, std::string> metadata = {});
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
std::unordered_map<std::string, array>);
|
std::unordered_map<std::string, array>,
|
||||||
|
std::unordered_map<std::string, std::string> metadata = {});
|
||||||
using MetaData =
|
|
||||||
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
|
|
||||||
|
|
||||||
/** Load array map and metadata from .gguf file format */
|
/** Load array map and metadata from .gguf file format */
|
||||||
std::pair<
|
|
||||||
std::unordered_map<std::string, array>,
|
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
|
||||||
std::unordered_map<std::string, MetaData>>
|
|
||||||
load_gguf(const std::string& file, StreamOrDevice s = {});
|
|
||||||
|
|
||||||
void save_gguf(
|
void save_gguf(
|
||||||
std::string file,
|
std::string file,
|
||||||
std::unordered_map<std::string, array> array_map,
|
std::unordered_map<std::string, array> array_map,
|
||||||
std::unordered_map<std::string, MetaData> meta_data = {});
|
std::unordered_map<std::string, GGUFMetaData> meta_data = {});
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -82,7 +82,7 @@ void set_mx_value_from_gguf(
|
|||||||
gguf_ctx* ctx,
|
gguf_ctx* ctx,
|
||||||
uint32_t type,
|
uint32_t type,
|
||||||
gguf_value* val,
|
gguf_value* val,
|
||||||
MetaData& value) {
|
GGUFMetaData& value) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case GGUF_VALUE_TYPE_UINT8:
|
case GGUF_VALUE_TYPE_UINT8:
|
||||||
value = array(val->uint8, uint8);
|
value = array(val->uint8, uint8);
|
||||||
@ -191,12 +191,12 @@ void set_mx_value_from_gguf(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) {
|
std::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ctx) {
|
||||||
std::unordered_map<std::string, MetaData> metadata;
|
std::unordered_map<std::string, GGUFMetaData> metadata;
|
||||||
gguf_key key;
|
gguf_key key;
|
||||||
while (gguf_get_key(ctx, &key)) {
|
while (gguf_get_key(ctx, &key)) {
|
||||||
std::string key_name = std::string(key.name, key.namelen);
|
std::string key_name = std::string(key.name, key.namelen);
|
||||||
auto& val = metadata.insert({key_name, MetaData{}}).first->second;
|
auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second;
|
||||||
set_mx_value_from_gguf(ctx, key.type, key.val, val);
|
set_mx_value_from_gguf(ctx, key.type, key.val, val);
|
||||||
}
|
}
|
||||||
return metadata;
|
return metadata;
|
||||||
@ -230,10 +230,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
|||||||
return array_map;
|
return array_map;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<
|
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
|
||||||
std::unordered_map<std::string, array>,
|
|
||||||
std::unordered_map<std::string, MetaData>>
|
|
||||||
load_gguf(const std::string& file, StreamOrDevice s) {
|
|
||||||
gguf_ctx* ctx = gguf_open(file.c_str());
|
gguf_ctx* ctx = gguf_open(file.c_str());
|
||||||
if (!ctx) {
|
if (!ctx) {
|
||||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||||
@ -280,7 +277,7 @@ void append_kv_array(
|
|||||||
void save_gguf(
|
void save_gguf(
|
||||||
std::string file,
|
std::string file,
|
||||||
std::unordered_map<std::string, array> array_map,
|
std::unordered_map<std::string, array> array_map,
|
||||||
std::unordered_map<std::string, MetaData> metadata /* = {} */) {
|
std::unordered_map<std::string, GGUFMetaData> metadata /* = {} */) {
|
||||||
// Add .gguf to file name if it is not there
|
// Add .gguf to file name if it is not there
|
||||||
if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
|
if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
|
||||||
file += ".gguf";
|
file += ".gguf";
|
||||||
|
@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Load array from reader in safetensor format */
|
/** Load array from reader in safetensor format */
|
||||||
std::unordered_map<std::string, array> load_safetensors(
|
SafetensorsLoad load_safetensors(
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
@ -121,9 +121,12 @@ std::unordered_map<std::string, array> load_safetensors(
|
|||||||
size_t offset = jsonHeaderLength + 8;
|
size_t offset = jsonHeaderLength + 8;
|
||||||
// Load the arrays using metadata
|
// Load the arrays using metadata
|
||||||
std::unordered_map<std::string, array> res;
|
std::unordered_map<std::string, array> res;
|
||||||
|
std::unordered_map<std::string, std::string> metadata_map;
|
||||||
for (const auto& item : metadata.items()) {
|
for (const auto& item : metadata.items()) {
|
||||||
if (item.key() == "__metadata__") {
|
if (item.key() == "__metadata__") {
|
||||||
// ignore metadata for now
|
for (const auto& meta_item : item.value().items()) {
|
||||||
|
metadata_map.insert({meta_item.key(), meta_item.value()});
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::string dtype = item.value().at("dtype");
|
std::string dtype = item.value().at("dtype");
|
||||||
@ -138,19 +141,18 @@ std::unordered_map<std::string, array> load_safetensors(
|
|||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
res.insert({item.key(), loaded_array});
|
res.insert({item.key(), loaded_array});
|
||||||
}
|
}
|
||||||
return res;
|
return {res, metadata_map};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, array> load_safetensors(
|
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
|
||||||
const std::string& file,
|
|
||||||
StreamOrDevice s) {
|
|
||||||
return load_safetensors(std::make_shared<io::FileReader>(file), s);
|
return load_safetensors(std::make_shared<io::FileReader>(file), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Save array to out stream in .npy format */
|
/** Save array to out stream in .npy format */
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
std::shared_ptr<io::Writer> out_stream,
|
std::shared_ptr<io::Writer> out_stream,
|
||||||
std::unordered_map<std::string, array> a) {
|
std::unordered_map<std::string, array> a,
|
||||||
|
std::unordered_map<std::string, std::string> metadata /* = {} */) {
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Check file
|
// Check file
|
||||||
if (!out_stream->good() || !out_stream->is_open()) {
|
if (!out_stream->good() || !out_stream->is_open()) {
|
||||||
@ -161,9 +163,11 @@ void save_safetensors(
|
|||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Check array map
|
// Check array map
|
||||||
json parent;
|
json parent;
|
||||||
parent["__metadata__"] = json::object({
|
json _metadata;
|
||||||
{"format", "mlx"},
|
for (auto& [key, value] : metadata) {
|
||||||
});
|
_metadata[key] = value;
|
||||||
|
}
|
||||||
|
parent["__metadata__"] = _metadata;
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& [key, arr] : a) {
|
for (auto& [key, arr] : a) {
|
||||||
arr.eval();
|
arr.eval();
|
||||||
@ -204,7 +208,8 @@ void save_safetensors(
|
|||||||
|
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
const std::string& file_,
|
const std::string& file_,
|
||||||
std::unordered_map<std::string, array> a) {
|
std::unordered_map<std::string, array> a,
|
||||||
|
std::unordered_map<std::string, std::string> metadata /* = {} */) {
|
||||||
// Open and check file
|
// Open and check file
|
||||||
std::string file = file_;
|
std::string file = file_;
|
||||||
|
|
||||||
@ -214,7 +219,7 @@ void save_safetensors(
|
|||||||
file += ".safetensors";
|
file += ".safetensors";
|
||||||
|
|
||||||
// Serialize array
|
// Serialize array
|
||||||
save_safetensors(std::make_shared<io::FileWriter>(file), a);
|
save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -160,31 +160,29 @@ class PyFileReader : public io::Reader {
|
|||||||
py::object tell_func_;
|
py::object tell_func_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
std::pair<
|
||||||
py::object file,
|
std::unordered_map<std::string, array>,
|
||||||
StreamOrDevice s) {
|
std::unordered_map<std::string, std::string>>
|
||||||
|
mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
|
||||||
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
||||||
return load_safetensors(py::cast<std::string>(file), s);
|
return load_safetensors(py::cast<std::string>(file), s);
|
||||||
} else if (is_istream_object(file)) {
|
} else if (is_istream_object(file)) {
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release gil;
|
||||||
for (auto& [key, arr] : arr) {
|
for (auto& [key, arr] : std::get<0>(res)) {
|
||||||
arr.eval();
|
arr.eval();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return arr;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[load_safetensors] Input must be a file-like object, or string");
|
"[load_safetensors] Input must be a file-like object, or string");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<
|
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
||||||
std::unordered_map<std::string, array>,
|
|
||||||
std::unordered_map<std::string, MetaData>>
|
|
||||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
|
||||||
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
|
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
|
||||||
return load_gguf(py::cast<std::string>(file), s);
|
return load_gguf(py::cast<std::string>(file), s);
|
||||||
}
|
}
|
||||||
@ -274,12 +272,16 @@ LoadOutputTypes mlx_load_helper(
|
|||||||
format.emplace(fname.substr(ext + 1));
|
format.emplace(fname.substr(ext + 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (return_metadata && format.value() != "gguf") {
|
if (return_metadata && (format.value() == "npy" || format.value() == "npz")) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[load] metadata not supported for format " + format.value());
|
"[load] metadata not supported for format " + format.value());
|
||||||
}
|
}
|
||||||
if (format.value() == "safetensors") {
|
if (format.value() == "safetensors") {
|
||||||
return mlx_load_safetensor_helper(file, s);
|
auto [dict, metadata] = mlx_load_safetensor_helper(file, s);
|
||||||
|
if (return_metadata) {
|
||||||
|
return std::make_pair(dict, metadata);
|
||||||
|
}
|
||||||
|
return dict;
|
||||||
} else if (format.value() == "npz") {
|
} else if (format.value() == "npz") {
|
||||||
return mlx_load_npz_helper(file, s);
|
return mlx_load_npz_helper(file, s);
|
||||||
} else if (format.value() == "npy") {
|
} else if (format.value() == "npy") {
|
||||||
@ -444,18 +446,33 @@ void mlx_savez_helper(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
void mlx_save_safetensor_helper(
|
||||||
|
py::object file,
|
||||||
|
py::dict d,
|
||||||
|
std::optional<py::dict> m) {
|
||||||
|
std::unordered_map<std::string, std::string> metadata_map;
|
||||||
|
if (m) {
|
||||||
|
try {
|
||||||
|
metadata_map =
|
||||||
|
m.value().cast<std::unordered_map<std::string, std::string>>();
|
||||||
|
} catch (const py::cast_error& e) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[save_safetensors] Metadata must be a dictionary with string keys and values");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
metadata_map = std::unordered_map<std::string, std::string>();
|
||||||
|
}
|
||||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||||
if (py::isinstance<py::str>(file)) {
|
if (py::isinstance<py::str>(file)) {
|
||||||
{
|
{
|
||||||
py::gil_scoped_release nogil;
|
py::gil_scoped_release nogil;
|
||||||
save_safetensors(py::cast<std::string>(file), arrays_map);
|
save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||||
}
|
}
|
||||||
} else if (is_ostream_object(file)) {
|
} else if (is_ostream_object(file)) {
|
||||||
auto writer = std::make_shared<PyFileWriter>(file);
|
auto writer = std::make_shared<PyFileWriter>(file);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release nogil;
|
py::gil_scoped_release nogil;
|
||||||
save_safetensors(writer, arrays_map);
|
save_safetensors(writer, arrays_map, metadata_map);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -471,7 +488,7 @@ void mlx_save_gguf_helper(
|
|||||||
if (py::isinstance<py::str>(file)) {
|
if (py::isinstance<py::str>(file)) {
|
||||||
if (m) {
|
if (m) {
|
||||||
auto metadata_map =
|
auto metadata_map =
|
||||||
m.value().cast<std::unordered_map<std::string, MetaData>>();
|
m.value().cast<std::unordered_map<std::string, GGUFMetaData>>();
|
||||||
{
|
{
|
||||||
py::gil_scoped_release nogil;
|
py::gil_scoped_release nogil;
|
||||||
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||||
|
@ -15,19 +15,17 @@ using namespace mlx::core;
|
|||||||
using LoadOutputTypes = std::variant<
|
using LoadOutputTypes = std::variant<
|
||||||
array,
|
array,
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, array>,
|
||||||
std::pair<
|
SafetensorsLoad,
|
||||||
std::unordered_map<std::string, array>,
|
GGUFLoad>;
|
||||||
std::unordered_map<std::string, MetaData>>>;
|
|
||||||
|
|
||||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s);
|
||||||
|
void mlx_save_safetensor_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
StreamOrDevice s);
|
py::dict d,
|
||||||
void mlx_save_safetensor_helper(py::object file, py::dict d);
|
std::optional<py::dict> m);
|
||||||
|
|
||||||
|
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s);
|
||||||
|
|
||||||
std::pair<
|
|
||||||
std::unordered_map<std::string, array>,
|
|
||||||
std::unordered_map<std::string, MetaData>>
|
|
||||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s);
|
|
||||||
void mlx_save_gguf_helper(
|
void mlx_save_gguf_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
py::dict d,
|
py::dict d,
|
||||||
|
@ -3214,8 +3214,9 @@ void init_ops(py::module_& m) {
|
|||||||
&mlx_save_safetensor_helper,
|
&mlx_save_safetensor_helper,
|
||||||
"file"_a,
|
"file"_a,
|
||||||
"arrays"_a,
|
"arrays"_a,
|
||||||
|
"metadata"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
save_safetensors(file: str, arrays: Dict[str, array])
|
save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None)
|
||||||
|
|
||||||
Save array(s) to a binary file in ``.safetensors`` format.
|
Save array(s) to a binary file in ``.safetensors`` format.
|
||||||
|
|
||||||
@ -3225,6 +3226,7 @@ void init_ops(py::module_& m) {
|
|||||||
Args:
|
Args:
|
||||||
file (file, str): File in which the array is saved.
|
file (file, str): File in which the array is saved.
|
||||||
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
||||||
|
metadata (dict(str, str), optional): The dictionary of metadata to be saved.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"save_gguf",
|
"save_gguf",
|
||||||
|
@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
def test_save_and_load_safetensors(self):
|
def test_save_and_load_safetensors(self):
|
||||||
if not os.path.isdir(self.test_dir):
|
if not os.path.isdir(self.test_dir):
|
||||||
os.mkdir(self.test_dir)
|
os.mkdir(self.test_dir)
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||||
|
|
||||||
|
mx.save_safetensors(
|
||||||
|
"test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
|
||||||
|
)
|
||||||
|
res = mx.load("test.safetensors", return_metadata=True)
|
||||||
|
self.assertEqual(len(res), 2)
|
||||||
|
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
|
||||||
|
|
||||||
for dt in self.dtypes + ["bfloat16"]:
|
for dt in self.dtypes + ["bfloat16"]:
|
||||||
with self.subTest(dtype=dt):
|
with self.subTest(dtype=dt):
|
||||||
|
@ -19,8 +19,14 @@ TEST_CASE("test save_safetensors") {
|
|||||||
auto map = std::unordered_map<std::string, array>();
|
auto map = std::unordered_map<std::string, array>();
|
||||||
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
||||||
map.insert({"test2", ones({2, 2})});
|
map.insert({"test2", ones({2, 2})});
|
||||||
save_safetensors(file_path, map);
|
auto _metadata = std::unordered_map<std::string, std::string>();
|
||||||
auto dict = load_safetensors(file_path);
|
_metadata.insert({"test", "test"});
|
||||||
|
_metadata.insert({"test2", "test2"});
|
||||||
|
save_safetensors(file_path, map, _metadata);
|
||||||
|
auto [dict, metadata] = load_safetensors(file_path);
|
||||||
|
|
||||||
|
CHECK_EQ(metadata, _metadata);
|
||||||
|
|
||||||
CHECK_EQ(dict.size(), 2);
|
CHECK_EQ(dict.size(), 2);
|
||||||
CHECK_EQ(dict.count("test"), 1);
|
CHECK_EQ(dict.count("test"), 1);
|
||||||
CHECK_EQ(dict.count("test2"), 1);
|
CHECK_EQ(dict.count("test2"), 1);
|
||||||
@ -55,7 +61,7 @@ TEST_CASE("test gguf") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test saving and loading string metadata
|
// Test saving and loading string metadata
|
||||||
std::unordered_map<std::string, MetaData> original_metadata;
|
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||||
original_metadata.insert({"test_str", "my string"});
|
original_metadata.insert({"test_str", "my string"});
|
||||||
|
|
||||||
save_gguf(file_path, original_weights, original_metadata);
|
save_gguf(file_path, original_weights, original_metadata);
|
||||||
@ -97,7 +103,7 @@ TEST_CASE("test gguf metadata") {
|
|||||||
|
|
||||||
// Scalar array
|
// Scalar array
|
||||||
{
|
{
|
||||||
std::unordered_map<std::string, MetaData> original_metadata;
|
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||||
original_metadata.insert({"test_arr", array(1.0)});
|
original_metadata.insert({"test_arr", array(1.0)});
|
||||||
save_gguf(file_path, original_weights, original_metadata);
|
save_gguf(file_path, original_weights, original_metadata);
|
||||||
|
|
||||||
@ -111,7 +117,7 @@ TEST_CASE("test gguf metadata") {
|
|||||||
|
|
||||||
// 1D Array
|
// 1D Array
|
||||||
{
|
{
|
||||||
std::unordered_map<std::string, MetaData> original_metadata;
|
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||||
auto arr = array({1.0, 2.0});
|
auto arr = array({1.0, 2.0});
|
||||||
original_metadata.insert({"test_arr", arr});
|
original_metadata.insert({"test_arr", arr});
|
||||||
save_gguf(file_path, original_weights, original_metadata);
|
save_gguf(file_path, original_weights, original_metadata);
|
||||||
@ -138,21 +144,21 @@ TEST_CASE("test gguf metadata") {
|
|||||||
|
|
||||||
// > 1D array throws
|
// > 1D array throws
|
||||||
{
|
{
|
||||||
std::unordered_map<std::string, MetaData> original_metadata;
|
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||||
original_metadata.insert({"test_arr", array({1.0}, {1, 1})});
|
original_metadata.insert({"test_arr", array({1.0}, {1, 1})});
|
||||||
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
|
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
|
||||||
}
|
}
|
||||||
|
|
||||||
// empty array throws
|
// empty array throws
|
||||||
{
|
{
|
||||||
std::unordered_map<std::string, MetaData> original_metadata;
|
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||||
original_metadata.insert({"test_arr", array({})});
|
original_metadata.insert({"test_arr", array({})});
|
||||||
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
|
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
|
||||||
}
|
}
|
||||||
|
|
||||||
// vector of string
|
// vector of string
|
||||||
{
|
{
|
||||||
std::unordered_map<std::string, MetaData> original_metadata;
|
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||||
std::vector<std::string> data = {"data1", "data2", "data1234"};
|
std::vector<std::string> data = {"data1", "data2", "data1234"};
|
||||||
original_metadata.insert({"meta", data});
|
original_metadata.insert({"meta", data});
|
||||||
save_gguf(file_path, original_weights, original_metadata);
|
save_gguf(file_path, original_weights, original_metadata);
|
||||||
@ -169,7 +175,7 @@ TEST_CASE("test gguf metadata") {
|
|||||||
|
|
||||||
// vector of string, string, scalar, and array
|
// vector of string, string, scalar, and array
|
||||||
{
|
{
|
||||||
std::unordered_map<std::string, MetaData> original_metadata;
|
std::unordered_map<std::string, GGUFMetaData> original_metadata;
|
||||||
std::vector<std::string> data = {"data1", "data2", "data1234"};
|
std::vector<std::string> data = {"data1", "data2", "data1234"};
|
||||||
original_metadata.insert({"meta1", data});
|
original_metadata.insert({"meta1", data});
|
||||||
original_metadata.insert({"meta2", array(2.5)});
|
original_metadata.insert({"meta2", array(2.5)});
|
||||||
|
Loading…
Reference in New Issue
Block a user