diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 318a3d082..75679a4fc 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -56,7 +56,6 @@ Token Tokenizer::getToken() { JSONNode parseJson(const char* data, size_t len) { auto tokenizer = Tokenizer(data, len); std::stack ctx; - std::string key; while (tokenizer.hasMoreTokens()) { auto token = tokenizer.getToken(); switch (token.type) { @@ -81,7 +80,7 @@ JSONNode parseJson(const char* data, size_t len) { if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { ctx.top()->getObject()->insert({key->getString(), obj}); } else { - throw new std::runtime_error("invalid json"); + throw std::runtime_error("invalid json"); } } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { auto list = ctx.top()->getList(); @@ -89,7 +88,6 @@ JSONNode parseJson(const char* data, size_t len) { } } break; - case TOKEN::ARRAY_CLOSE: if (ctx.top()->is_type(JSONNode::Type::STRING)) { // key is above @@ -115,7 +113,9 @@ JSONNode parseJson(const char* data, size_t len) { case TOKEN::STRING: { auto str = new std::string(data + token.start, token.end - token.start - 1); - if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { + if (ctx.top()->is_type(JSONNode::Type::LIST)) { + ctx.top()->getList()->push_back(new JSONNode(str)); + } else if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { ctx.push(new JSONNode(str)); } else if (ctx.top()->is_type(JSONNode::Type::STRING)) { auto key = ctx.top(); @@ -124,10 +124,8 @@ JSONNode parseJson(const char* data, size_t len) { ctx.top()->getObject()->insert( {key->getString(), new JSONNode(str)}); } else { - throw new std::runtime_error("invalid json"); + throw std::runtime_error("invalid json"); } - } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { - ctx.top()->getList()->push_back(new JSONNode(str)); } break; } @@ -144,7 +142,7 @@ JSONNode parseJson(const char* data, size_t len) { ctx.top()->getObject()->insert( {key->getString(), new JSONNode(val)}); } else { - throw new std::runtime_error("invalid json"); + throw std::runtime_error("invalid json"); } } break; @@ -157,6 +155,7 @@ JSONNode parseJson(const char* data, size_t len) { break; } } + throw std::runtime_error("invalid json"); } } // namespace io @@ -181,6 +180,11 @@ std::map load_safetensor( // Load the json metadata char json[jsonHeaderLength]; in_stream->read(json, jsonHeaderLength); + auto metadata = io::parseJson(json, jsonHeaderLength); + if (!metadata.is_type(io::JSONNode::Type::OBJECT)) { + throw std::runtime_error( + "[load_safetensor] Invalid json metadata " + in_stream->label()); + } // Parse the json raw data std::map res; return res; diff --git a/mlx/safetensor.h b/mlx/safetensor.h index 2fb5957c0..941b210c5 100644 --- a/mlx/safetensor.h +++ b/mlx/safetensor.h @@ -18,7 +18,7 @@ using JSONList = std::vector; class JSONNode { public: - enum class Type { OBJECT, LIST, STRING, NUMBER, BOOLEAN, NULL_TYPE }; + enum class Type { OBJECT, LIST, STRING, NUMBER, NULL_TYPE }; JSONNode() : _type(Type::NULL_TYPE){}; JSONNode(Type type) : _type(type) { diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 73fa905dd..5eae0276a 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -47,6 +47,9 @@ TEST_CASE("test parseJson") { res = io::parseJson(raw.c_str(), raw.size()); CHECK(res.is_type(io::JSONNode::Type::LIST)); + raw = std::string("["); + CHECK_THROWS_AS(io::parseJson(raw.c_str(), raw.size()), std::runtime_error); + raw = std::string("[{}, \"test\"]"); res = io::parseJson(raw.c_str(), raw.size()); CHECK(res.is_type(io::JSONNode::Type::LIST));