added format mlx to metadata

This commit is contained in:
dc-dc-dc 2023-12-20 12:16:29 -05:00
parent 29e43170c4
commit c81f5a5b94

View File

@ -127,10 +127,19 @@ std::unordered_map<std::string, array> load_safetensor(
void save_safetensor( void save_safetensor(
std::shared_ptr<io::Writer> out_stream, std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a) { std::unordered_map<std::string, array> a) {
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
throw std::runtime_error(
"[save_safetensor] Failed to open " + out_stream->label());
}
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Check array map // Check array map
json parent; json parent;
parent["__metadata__"] = json::object({
{"format", "mlx"},
});
size_t offset = 0; size_t offset = 0;
for (auto& [key, arr] : a) { for (auto& [key, arr] : a) {
arr.eval(false); arr.eval(false);
@ -145,7 +154,6 @@ void save_safetensor(
key); key);
} }
json child; json child;
// TODO: dont make a new string
child["dtype"] = dtype_to_safetensor_str(arr.dtype()); child["dtype"] = dtype_to_safetensor_str(arr.dtype());
child["shape"] = arr.shape(); child["shape"] = arr.shape();
child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()}; child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()};
@ -153,13 +161,6 @@ void save_safetensor(
offset += arr.nbytes(); offset += arr.nbytes();
} }
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
throw std::runtime_error(
"[save_safetensor] Failed to open " + out_stream->label());
}
auto header = parent.dump(); auto header = parent.dump();
uint64_t header_len = header.length(); uint64_t header_len = header.length();
out_stream->write(reinterpret_cast<char*>(&header_len), 8); out_stream->write(reinterpret_cast<char*>(&header_len), 8);