mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
added format mlx to metadata
This commit is contained in:
parent
29e43170c4
commit
c81f5a5b94
@ -127,10 +127,19 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
void save_safetensor(
|
||||
std::shared_ptr<io::Writer> out_stream,
|
||||
std::unordered_map<std::string, array> a) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check file
|
||||
if (!out_stream->good() || !out_stream->is_open()) {
|
||||
throw std::runtime_error(
|
||||
"[save_safetensor] Failed to open " + out_stream->label());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Check array map
|
||||
json parent;
|
||||
|
||||
parent["__metadata__"] = json::object({
|
||||
{"format", "mlx"},
|
||||
});
|
||||
size_t offset = 0;
|
||||
for (auto& [key, arr] : a) {
|
||||
arr.eval(false);
|
||||
@ -145,7 +154,6 @@ void save_safetensor(
|
||||
key);
|
||||
}
|
||||
json child;
|
||||
// TODO: dont make a new string
|
||||
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
||||
child["shape"] = arr.shape();
|
||||
child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()};
|
||||
@ -153,13 +161,6 @@ void save_safetensor(
|
||||
offset += arr.nbytes();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Check file
|
||||
if (!out_stream->good() || !out_stream->is_open()) {
|
||||
throw std::runtime_error(
|
||||
"[save_safetensor] Failed to open " + out_stream->label());
|
||||
}
|
||||
|
||||
auto header = parent.dump();
|
||||
uint64_t header_len = header.length();
|
||||
out_stream->write(reinterpret_cast<char*>(&header_len), 8);
|
||||
|
Loading…
Reference in New Issue
Block a user