diff --git a/.gitignore b/.gitignore
index 8190c9557..8dfe5038e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,7 +8,7 @@ __pycache__/
 
 # tensor files
 *.safe
-*.safetensor
+*.safetensors
 
 # Metal libraries
 *.metallib
diff --git a/mlx/ops.h b/mlx/ops.h
index bd0521cc8..bd002b5df 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1065,4 +1065,10 @@ std::unordered_map<std::string, array> load_safetensor(
     const std::string& file,
     StreamOrDevice s = {});
 
+void save_safetensor(
+    std::shared_ptr<io::Writer> out_stream,
+    std::unordered_map<std::string, array>);
+void save_safetensor(
+    const std::string& file,
+    std::unordered_map<std::string, array>);
 } // namespace mlx::core
diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp
index 410255627..85b69ada6 100644
--- a/mlx/safetensor.cpp
+++ b/mlx/safetensor.cpp
@@ -55,7 +55,7 @@ Token Tokenizer::getToken() {
   }
 }
 
-JSONNode parseJson(const char* data, size_t len) {
+JSONNode jsonDeserialize(const char* data, size_t len) {
   auto tokenizer = Tokenizer(data, len);
   std::stack<JSONNode*> ctx;
   while (tokenizer.hasMoreTokens()) {
@@ -137,7 +137,7 @@ JSONNode parseJson(const char* data, size_t len) {
      case TOKEN::NUMBER: {
        // TODO: is there an easier way of doing this.
        auto str = new std::string(data + token.start, token.end - token.start);
-       float val = strtof(str->c_str(), nullptr);
+       auto val = strtoul(str->c_str(), nullptr, 10);
        if (ctx.top()->is_type(JSONNode::Type::LIST)) {
          ctx.top()->getList()->push_back(new JSONNode(val));
        } else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
@@ -152,21 +152,83 @@
        }
        break;
      }
-     case TOKEN::COMMA:
-       break;
-     case TOKEN::COLON:
-       break;
-     case TOKEN::NULL_TYPE:
+     default:
        break;
    }
  }
  throw std::runtime_error(
-     "[parseJson] json was invalid and could not be parsed");
+     "[jsonDeserialize] json was invalid and could not be parsed");
+}
+
+std::string jsonSerialize(JSONNode* node) {
+  std::string res;
+  if (node->is_type(JSONNode::Type::STRING)) {
+    return "\"" + node->getString() + "\"";
+  }
+  if (node->is_type(JSONNode::Type::NUMBER)) {
+    return std::to_string(node->getNumber());
+  }
+  if (node->is_type(JSONNode::Type::LIST)) {
+    res += "[";
+    for (auto& item : *node->getList()) {
+      res += jsonSerialize(item);
+      res += ",";
+    }
+    if (res.back() == ',') {
+      res.pop_back();
+    }
+    res += "]";
+    return res;
+  }
+  if (node->is_type(JSONNode::Type::OBJECT)) {
+    res += "{";
+    for (auto& [key, item] : *node->getObject()) {
+      res += "\"" + key + "\":";
+      res += jsonSerialize(item);
+      res += ",";
+    }
+    if (res.back() == ',') {
+      res.pop_back();
+    }
+    res += "}";
+    return res;
+  }
+
+  throw std::runtime_error("[jsonSerialize] invalid json node");
 }
 
 } // namespace io
 
+std::string dtype_to_safetensor_str(Dtype t) {
+  if (t == float32) {
+    return ST_F32;
+  } else if (t == bfloat16) {
+    return ST_BF16;
+  } else if (t == float16) {
+    return ST_F16;
+  } else if (t == int64) {
+    return ST_I64;
+  } else if (t == int32) {
+    return ST_I32;
+  } else if (t == int16) {
+    return ST_I16;
+  } else if (t == int8) {
+    return ST_I8;
+  } else if (t == uint64) {
+    return ST_U64;
+  } else if (t == uint32) {
+    return ST_U32;
+  } else if (t == uint16) {
+    return ST_U16;
+  } else if (t == uint8) {
+    return ST_U8;
+  } else if (t == bool_) {
+    return ST_BOOL;
+  } else {
+    throw std::runtime_error("[safetensor] unsupported dtype");
+  }
+}
-Dtype dtype_from_safe_tensor_str(std::string str) {
+Dtype dtype_from_safetensor_str(std::string str) {
   if (str == ST_F32) {
     return float32;
   } else if (str == ST_F16) {
@@ -216,7 +278,7 @@ std::unordered_map<std::string, array> load_safetensor(
   // Load the json metadata
   char json[jsonHeaderLength];
   in_stream->read(json, jsonHeaderLength);
-  auto metadata = io::parseJson(json, jsonHeaderLength);
+  auto metadata = io::jsonDeserialize(json, jsonHeaderLength);
   // Should always be an object on the top-level
   if (!metadata.is_type(io::JSONNode::Type::OBJECT)) {
     throw std::runtime_error(
@@ -237,7 +299,7 @@ std::unordered_map<std::string, array> load_safetensor(
     for (const auto& offset : *data_offsets) {
       data_offsets_vec.push_back(offset->getNumber());
     }
-    Dtype type = dtype_from_safe_tensor_str(dtype);
+    Dtype type = dtype_from_safetensor_str(dtype);
     auto loaded_array = array(
         shape_vec,
         float32,
@@ -259,4 +321,82 @@
   return load_safetensor(std::make_shared<io::FileReader>(file), s);
 }
 
+/** Save array map to out stream in .safetensors format */
+void save_safetensor(
+    std::shared_ptr<io::Writer> out_stream,
+    std::unordered_map<std::string, array> a) {
+  ////////////////////////////////////////////////////////
+  // Check array map
+
+  io::JSONNode metadata(io::JSONNode::Type::OBJECT);
+  size_t offset = 0;
+  for (auto& [key, arr] : a) {
+    arr.eval(false);
+    if (arr.nbytes() == 0) {
+      throw std::invalid_argument(
+          "[save_safetensor] cannot serialize an empty array key: " + key);
+    }
+
+    if (!arr.flags().contiguous) {
+      throw std::invalid_argument(
+          "[save_safetensor] cannot serialize a non-contiguous array key: " +
+          key);
+    }
+    auto obj = new io::JSONNode(io::JSONNode::Type::OBJECT);
+    // TODO: don't make a new string
+    obj->getObject()->insert(
+        {"dtype",
+         new io::JSONNode(
+             new std::string(dtype_to_safetensor_str(arr.dtype())))});
+    obj->getObject()->insert(
+        {"shape", new io::JSONNode(io::JSONNode::Type::LIST)});
+    for (auto& dim : arr.shape()) {
+      obj->getObject()->at("shape")->getList()->push_back(
+          new io::JSONNode(dim));
+    }
+    obj->getObject()->insert(
+        {"data_offsets", new io::JSONNode(io::JSONNode::Type::LIST)});
+    obj->getObject()
+        ->at("data_offsets")
+        ->getList()
+        ->push_back(new io::JSONNode(offset));
+    obj->getObject()
+        ->at("data_offsets")
+        ->getList()
+        ->push_back(new io::JSONNode(offset + arr.nbytes()));
+    metadata.getObject()->insert({key, obj});
+    offset += arr.nbytes();
+  }
+
+  ////////////////////////////////////////////////////////
+  // Check file
+  if (!out_stream->good() || !out_stream->is_open()) {
+    throw std::runtime_error(
+        "[save_safetensor] Failed to open " + out_stream->label());
+  }
+
+  auto header = io::jsonSerialize(&metadata);
+  uint64_t header_len = header.length();
+  out_stream->write(reinterpret_cast<char*>(&header_len), 8);
+  out_stream->write(header.c_str(), header_len);
+  for (auto& [key, arr] : a) {
+    out_stream->write(arr.data<char>(), arr.nbytes());
+  }
+}
+
+void save_safetensor(
+    const std::string& file_,
+    std::unordered_map<std::string, array> a) {
+  // Open and check file
+  std::string file = file_;
+
+  // Add .safetensors to file name if it is not there
+  if (file.length() < 12 ||
+      file.substr(file.length() - 12, 12) != ".safetensors")
+    file += ".safetensors";
+
+  // Serialize array
+  save_safetensor(std::make_shared<io::FileWriter>(file), a);
+}
+
 } // namespace mlx::core
\ No newline at end of file
diff --git a/mlx/safetensor.h b/mlx/safetensor.h
index 662c4ee00..19aa26aa2 100644
--- a/mlx/safetensor.h
+++ b/mlx/safetensor.h
@@ -74,7 +74,7 @@ class JSONNode {
     return *this->_values.s;
   }
 
-  float getNumber() {
+  uint32_t getNumber() {
     if (!is_type(Type::NUMBER)) {
       throw new std::runtime_error("not a number");
     }
@@ -94,12 +94,13 @@ class JSONNode {
     JSONObject* object;
     JSONList* list;
    std::string* s;
-    float f;
+    uint32_t f;
   } _values;
 
   Type _type;
 };
 
-JSONNode parseJson(const char* data, size_t len);
+JSONNode jsonDeserialize(const char* data, size_t len);
+std::string jsonSerialize(JSONNode* node);
 enum class TOKEN {
   CURLY_OPEN,
diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp
index c4fc554b2..49f94eb1d 100644
--- a/tests/load_tests.cpp
+++ b/tests/load_tests.cpp
@@ -38,20 +38,42 @@ TEST_CASE("test tokenizer") {
   CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
 }
 
-TEST_CASE("test parseJson") {
+TEST_CASE("test jsonSerialize") {
+  auto test = new io::JSONNode(io::JSONNode::Type::OBJECT);
+  auto src = io::jsonSerialize(test);
+  CHECK_EQ(src, "{}");
+  test = new io::JSONNode(io::JSONNode::Type::LIST);
+  src = io::jsonSerialize(test);
+  CHECK_EQ(src, "[]");
+  test = new io::JSONNode(io::JSONNode::Type::OBJECT);
+  test->getObject()->insert(
+      {"test", new io::JSONNode(new std::string("testing"))});
+  src = io::jsonSerialize(test);
+  CHECK_EQ(src, "{\"test\":\"testing\"}");
+  test = new io::JSONNode(io::JSONNode::Type::OBJECT);
+  auto arr = new io::JSONNode(io::JSONNode::Type::LIST);
+  arr->getList()->push_back(new io::JSONNode(1));
+  arr->getList()->push_back(new io::JSONNode(2));
+  test->getObject()->insert({"test", arr});
+  src = io::jsonSerialize(test);
+  CHECK_EQ(src, "{\"test\":[1,2]}");
+}
+
+TEST_CASE("test jsonDeserialize") {
   auto raw = std::string("{}");
-  auto res = io::parseJson(raw.c_str(), raw.size());
+  auto res = io::jsonDeserialize(raw.c_str(), raw.size());
   CHECK(res.is_type(io::JSONNode::Type::OBJECT));
 
   raw = std::string("[]");
-  res = io::parseJson(raw.c_str(), raw.size());
+  res = io::jsonDeserialize(raw.c_str(), raw.size());
   CHECK(res.is_type(io::JSONNode::Type::LIST));
 
   raw = std::string("[");
-  CHECK_THROWS_AS(io::parseJson(raw.c_str(), raw.size()), std::runtime_error);
+  CHECK_THROWS_AS(
+      io::jsonDeserialize(raw.c_str(), raw.size()), std::runtime_error);
 
   raw = std::string("[{}, \"test\"]");
-  res = io::parseJson(raw.c_str(), raw.size());
+  res = io::jsonDeserialize(raw.c_str(), raw.size());
   CHECK(res.is_type(io::JSONNode::Type::LIST));
   CHECK_EQ(res.getList()->size(), 2);
   CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT));
@@ -60,7 +82,7 @@
 
   raw = std::string(
       "{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
-  res = io::parseJson(raw.c_str(), raw.size());
+  res = io::jsonDeserialize(raw.c_str(), raw.size());
   CHECK(res.is_type(io::JSONNode::Type::OBJECT));
   CHECK_EQ(res.getObject()->size(), 1);
   CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
@@ -138,8 +160,16 @@
       16);
 }
 
+TEST_CASE("test save_safetensor") {
+  auto map = std::unordered_map<std::string, array>();
+  map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
+  map.insert({"test2", ones({2, 2})});
+  MESSAGE("SAVING");
+  save_safetensor("../../temp1", map);
+}
+
 TEST_CASE("test load_safetensor") {
-  auto safeDict = load_safetensor("../../temp.safe");
+  auto safeDict = load_safetensor("../../temp1.safetensors");
   CHECK_EQ(safeDict.size(), 2);
   CHECK_EQ(safeDict.count("test"), 1);
   CHECK_EQ(safeDict.count("test2"), 1);
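
Usage note (not part of the patch): the writer above emits the safetensors layout as an 8-byte header length, the JSON header produced by jsonSerialize, then the raw tensor bytes in map order; on a little-endian host this matches the format's header convention. Below is a minimal round-trip sketch of the new API, assuming the patch is applied to an MLX checkout. The file name and assertions are illustrative only; ones, zeros, and mlx/mlx.h are existing MLX pieces, and the ".safetensors" suffix is appended by save_safetensor when missing.

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Build a name -> array map and write it out as weights.safetensors.
  std::unordered_map<std::string, array> weights;
  weights.insert({"w", ones({2, 2})});
  weights.insert({"b", zeros({2})});
  save_safetensor("weights", weights); // ".safetensors" is appended if missing

  // Read the file back; keys and shapes should round-trip.
  auto loaded = load_safetensor("weights.safetensors");
  assert(loaded.size() == 2);
  assert(loaded.at("w").shape() == std::vector<int>({2, 2}));
  return 0;
}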