From c81f5a5b946b4188875ab45a762dca6c9543ab4b Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Wed, 20 Dec 2023 12:16:29 -0500 Subject: [PATCH] added format mlx to metadata --- mlx/safetensor.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 4647fa37f..f511faf80 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -127,10 +127,19 @@ std::unordered_map load_safetensor( void save_safetensor( std::shared_ptr out_stream, std::unordered_map a) { + //////////////////////////////////////////////////////// + // Check file + if (!out_stream->good() || !out_stream->is_open()) { + throw std::runtime_error( + "[save_safetensor] Failed to open " + out_stream->label()); + } + //////////////////////////////////////////////////////// // Check array map json parent; - + parent["__metadata__"] = json::object({ + {"format", "mlx"}, + }); size_t offset = 0; for (auto& [key, arr] : a) { arr.eval(false); @@ -145,7 +154,6 @@ void save_safetensor( key); } json child; - // TODO: dont make a new string child["dtype"] = dtype_to_safetensor_str(arr.dtype()); child["shape"] = arr.shape(); child["data_offsets"] = std::vector{offset, offset + arr.nbytes()}; @@ -153,13 +161,6 @@ void save_safetensor( offset += arr.nbytes(); } - //////////////////////////////////////////////////////// - // Check file - if (!out_stream->good() || !out_stream->is_open()) { - throw std::runtime_error( - "[save_safetensor] Failed to open " + out_stream->label()); - } - auto header = parent.dump(); uint64_t header_len = header.length(); out_stream->write(reinterpret_cast(&header_len), 8);