mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 21:16:47 +08:00
added format mlx to metadata
This commit is contained in:
parent
29e43170c4
commit
c81f5a5b94
@ -127,10 +127,19 @@ std::unordered_map<std::string, array> load_safetensor(
|
|||||||
void save_safetensor(
|
void save_safetensor(
|
||||||
std::shared_ptr<io::Writer> out_stream,
|
std::shared_ptr<io::Writer> out_stream,
|
||||||
std::unordered_map<std::string, array> a) {
|
std::unordered_map<std::string, array> a) {
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
// Check file
|
||||||
|
if (!out_stream->good() || !out_stream->is_open()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[save_safetensor] Failed to open " + out_stream->label());
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Check array map
|
// Check array map
|
||||||
json parent;
|
json parent;
|
||||||
|
parent["__metadata__"] = json::object({
|
||||||
|
{"format", "mlx"},
|
||||||
|
});
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& [key, arr] : a) {
|
for (auto& [key, arr] : a) {
|
||||||
arr.eval(false);
|
arr.eval(false);
|
||||||
@ -145,7 +154,6 @@ void save_safetensor(
|
|||||||
key);
|
key);
|
||||||
}
|
}
|
||||||
json child;
|
json child;
|
||||||
// TODO: dont make a new string
|
|
||||||
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
||||||
child["shape"] = arr.shape();
|
child["shape"] = arr.shape();
|
||||||
child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()};
|
child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()};
|
||||||
@ -153,13 +161,6 @@ void save_safetensor(
|
|||||||
offset += arr.nbytes();
|
offset += arr.nbytes();
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////
|
|
||||||
// Check file
|
|
||||||
if (!out_stream->good() || !out_stream->is_open()) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[save_safetensor] Failed to open " + out_stream->label());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto header = parent.dump();
|
auto header = parent.dump();
|
||||||
uint64_t header_len = header.length();
|
uint64_t header_len = header.length();
|
||||||
out_stream->write(reinterpret_cast<char*>(&header_len), 8);
|
out_stream->write(reinterpret_cast<char*>(&header_len), 8);
|
||||||
|
Loading…
Reference in New Issue
Block a user