diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 0e4838fa0..ae1139d7b 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -36,17 +36,14 @@ Token Tokenizer::getToken() { this->_loc++; return Token{TOKEN::COMMA}; case '"': { - size_t start = this->_loc; - this->_loc++; - while (_data[this->_loc] != '"' && this->hasMoreTokens()) { - this->_loc++; - } + size_t start = ++this->_loc; + while (_data[++this->_loc] != '"' && this->hasMoreTokens()) + ; if (!this->hasMoreTokens()) { throw new std::runtime_error("no more chars to parse"); } // pass the last " - this->_loc++; - return Token{TOKEN::STRING, start, this->_loc}; + return Token{TOKEN::STRING, start, ++this->_loc}; } default: { size_t start = this->_loc; @@ -63,66 +60,105 @@ Token Tokenizer::getToken() { } } -// JSONNode parseJson(char* data, size_t len) { -// auto tokenizer = Tokenizer(data, len); -// std::stack ctx; -// auto token = tokenizer.getToken(); -// auto parent = new JSONNode(); +JSONNode parseJson(const char* data, size_t len) { + auto tokenizer = Tokenizer(data, len); + std::stack ctx; + while (tokenizer.hasMoreTokens()) { + auto token = tokenizer.getToken(); + switch (token.type) { + case TOKEN::NULL_TYPE: + break; + case TOKEN::CURLY_OPEN: + ctx.push(new JSONNode(JSONNode::Type::OBJECT)); + break; + case TOKEN::CURLY_CLOSE: + if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { + auto obj = ctx.top(); + ctx.pop(); + // top-level object + if (ctx.size() == 0) { + return *obj; + } -// switch (token.type) { -// case TOKEN::CURLY_OPEN: -// parent->setObject(new JSONObject()); -// break; -// case TOKEN::ARRAY_OPEN: -// parent->setList(new JSONList()); -// break; -// default: -// throw new std::runtime_error("invalid json"); -// } -// ctx.push(parent); + if (ctx.top()->is_type(JSONNode::Type::STRING)) { + auto key = ctx.top(); + ctx.pop(); + if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { + ctx.top()->getObject()->insert({key->getString(), obj}); + } else { + throw new std::runtime_error("invalid json"); + } + } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { + auto list = ctx.top()->getList(); + list->push_back(obj); + } + } + break; + case TOKEN::ARRAY_OPEN: + ctx.push(new JSONNode(JSONNode::Type::LIST)); + break; + case TOKEN::ARRAY_CLOSE: + if (ctx.top()->is_type(JSONNode::Type::STRING)) { + // key is above + auto key = ctx.top(); + ctx.pop(); + if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { + ctx.top()->getObject()->insert({key->getString(), new JSONNode()}); + } + } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { + auto obj = ctx.top(); + ctx.pop(); + // top-level object + if (ctx.size() == 0) { + return *obj; + } -// while (tokenizer.hasMoreTokens()) { -// auto token = tokenizer.getToken(); -// switch (token.type) { -// case TOKEN::CURLY_OPEN: -// ctx.push(new JSONNode(JSONNode::Type::OBJECT)); -// break; -// case TOKEN::CURLY_CLOSE: -// if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { -// auto obj = ctx.top(); -// ctx.pop(); -// if (ctx.top()->is_type(JSONNode::Type::LIST)) { -// auto list = ctx.top()->getList(); -// list->push_back(obj); -// } else if (ctx.top()->is_type(JSONNode::Type::STRING)) { -// // -// auto key = ctx.top(); -// ctx.pop(); -// if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { -// ctx.top()->getObject()->insert({key->getString(), obj}); -// } -// } -// } else { -// throw new std::runtime_error("invalid json"); -// } -// break; -// case TOKEN::COLON: -// break; -// case TOKEN::ARRAY_OPEN: -// break; -// case TOKEN::ARRAY_CLOSE: -// break; -// case TOKEN::COMMA: -// break; -// case TOKEN::NULL_TYPE: -// break; -// case TOKEN::STRING: -// break; -// case TOKEN::NUMBER: -// break; -// } -// } -// } + if (ctx.top()->is_type(JSONNode::Type::LIST)) { + auto list = ctx.top()->getList(); + list->push_back(obj); + } + } + break; + case TOKEN::STRING: { + auto str = + new std::string(data + token.start, token.end - token.start - 1); + if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { + ctx.push(new JSONNode(str)); + } else if (ctx.top()->is_type(JSONNode::Type::STRING)) { + auto key = ctx.top(); + ctx.pop(); + if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { + ctx.top()->getObject()->insert( + {key->getString(), new JSONNode(str)}); + } else { + throw new std::runtime_error("invalid json"); + } + } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { + ctx.top()->getList()->push_back(new JSONNode(str)); + } + break; + } + case TOKEN::NUMBER: + if (ctx.top()->is_type(JSONNode::Type::LIST)) { + ctx.top()->getList()->push_back(new JSONNode(JSONNode::Type::NUMBER)); + } else if (ctx.top()->is_type(JSONNode::Type::STRING)) { + auto key = ctx.top(); + ctx.pop(); + if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { + ctx.top()->getObject()->insert( + {key->getString(), new JSONNode(JSONNode::Type::NUMBER)}); + } else { + throw new std::runtime_error("invalid json"); + } + } + break; + case TOKEN::COMMA: + break; + case TOKEN::COLON: + break; + } + } +} } // namespace io @@ -141,7 +177,7 @@ std::map load_safetensor( in_stream->read(reinterpret_cast(&jsonHeaderLength), 8); if (jsonHeaderLength <= 0) { throw std::runtime_error( - "[load_safetensor] Invalid json header lenght " + in_stream->label()); + "[load_safetensor] Invalid json header length " + in_stream->label()); } // Load the json metadata char json[jsonHeaderLength]; diff --git a/mlx/safetensor.h b/mlx/safetensor.h index 097711816..272d29979 100644 --- a/mlx/safetensor.h +++ b/mlx/safetensor.h @@ -13,8 +13,8 @@ namespace mlx::core { namespace io { class JSONNode; -using JSONObject = std::map>; -using JSONList = std::vector>; +using JSONObject = std::map; +using JSONList = std::vector; class JSONNode { public: @@ -29,11 +29,39 @@ class JSONNode { this->_values.list = new JSONList(); } }; + JSONNode(std::string* s) : _type(Type::STRING) { + this->_values.s = s; + }; + + JSONObject* getObject() { + if (!is_type(Type::OBJECT)) { + throw new std::runtime_error("not an object"); + } + return this->_values.object; + } + + JSONList* getList() { + if (!is_type(Type::LIST)) { + throw new std::runtime_error("not a list"); + } + return this->_values.list; + } + + std::string getString() { + if (!is_type(Type::STRING)) { + throw new std::runtime_error("not a string"); + } + return *this->_values.s; + } inline bool is_type(Type t) { return this->_type == t; } + inline Type type() const { + return this->_type; + } + private: union Values { JSONObject* object; @@ -44,6 +72,8 @@ class JSONNode { Type _type; }; +JSONNode parseJson(const char* data, size_t len); + enum class TOKEN { CURLY_OPEN, CURLY_CLOSE, diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 471313aa0..70cddcc6e 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -27,11 +27,48 @@ TEST_CASE("test tokenizer") { CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE); CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE); CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE); + + raw = std::string(" { \"testing\": \"test\"} "); + tokenizer = io::Tokenizer(raw.c_str(), raw.size()); + CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN); + CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING); + CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON); + CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING); + CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE); + CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE); } -// TEST_CASE("test load_safetensor") { -// auto array = load_safetensor("../../temp.safe"); -// } +TEST_CASE("test load_safetensor") { + auto raw = std::string("{}"); + auto res = io::parseJson(raw.c_str(), raw.size()); + CHECK(res.is_type(io::JSONNode::Type::OBJECT)); + raw = std::string("[]"); + res = io::parseJson(raw.c_str(), raw.size()); + CHECK(res.is_type(io::JSONNode::Type::LIST)); + raw = std::string("[{}, \"test\"]"); + res = io::parseJson(raw.c_str(), raw.size()); + CHECK(res.is_type(io::JSONNode::Type::LIST)); + CHECK_EQ(res.getList()->size(), 2); + CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT)); + CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING)); + MESSAGE(res.getList()->at(1)->getString()); + CHECK_EQ(res.getList()->at(1)->getString(), "test"); + raw = std::string("{\"test\": \"test\", \"test_num\": 1}"); + res = io::parseJson(raw.c_str(), raw.size()); + CHECK(res.is_type(io::JSONNode::Type::OBJECT)); + CHECK_EQ(res.getObject()->size(), 2); + CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::STRING)); + CHECK_EQ(res.getObject()->at("test")->getString(), "test"); + CHECK(res.getObject()->at("test_num")->is_type(io::JSONNode::Type::NUMBER)); + raw = std::string("{\"test\": { \"test\": \"test\"}}"); + res = io::parseJson(raw.c_str(), raw.size()); + CHECK(res.is_type(io::JSONNode::Type::OBJECT)); + CHECK_EQ(res.getObject()->size(), 1); + CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT)); + CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1); + CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type( + io::JSONNode::Type::STRING)); +} TEST_CASE("test single array serialization") { // Basic test