more cleanup

This commit is contained in:
dc-dc-dc 2023-12-18 10:00:10 -05:00
parent d0285db98c
commit 91495382fd
3 changed files with 16 additions and 9 deletions

View File

@ -56,7 +56,6 @@ Token Tokenizer::getToken() {
JSONNode parseJson(const char* data, size_t len) {
auto tokenizer = Tokenizer(data, len);
std::stack<JSONNode*> ctx;
std::string key;
while (tokenizer.hasMoreTokens()) {
auto token = tokenizer.getToken();
switch (token.type) {
@ -81,7 +80,7 @@ JSONNode parseJson(const char* data, size_t len) {
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
ctx.top()->getObject()->insert({key->getString(), obj});
} else {
throw new std::runtime_error("invalid json");
throw std::runtime_error("invalid json");
}
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
auto list = ctx.top()->getList();
@ -89,7 +88,6 @@ JSONNode parseJson(const char* data, size_t len) {
}
}
break;
case TOKEN::ARRAY_CLOSE:
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
// key is above
@ -115,7 +113,9 @@ JSONNode parseJson(const char* data, size_t len) {
case TOKEN::STRING: {
auto str =
new std::string(data + token.start, token.end - token.start - 1);
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
ctx.top()->getList()->push_back(new JSONNode(str));
} else if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
ctx.push(new JSONNode(str));
} else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
auto key = ctx.top();
@ -124,10 +124,8 @@ JSONNode parseJson(const char* data, size_t len) {
ctx.top()->getObject()->insert(
{key->getString(), new JSONNode(str)});
} else {
throw new std::runtime_error("invalid json");
throw std::runtime_error("invalid json");
}
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
ctx.top()->getList()->push_back(new JSONNode(str));
}
break;
}
@ -144,7 +142,7 @@ JSONNode parseJson(const char* data, size_t len) {
ctx.top()->getObject()->insert(
{key->getString(), new JSONNode(val)});
} else {
throw new std::runtime_error("invalid json");
throw std::runtime_error("invalid json");
}
}
break;
@ -157,6 +155,7 @@ JSONNode parseJson(const char* data, size_t len) {
break;
}
}
throw std::runtime_error("invalid json");
}
} // namespace io
@ -181,6 +180,11 @@ std::map<std::string, array> load_safetensor(
// Load the json metadata
char json[jsonHeaderLength];
in_stream->read(json, jsonHeaderLength);
auto metadata = io::parseJson(json, jsonHeaderLength);
if (!metadata.is_type(io::JSONNode::Type::OBJECT)) {
throw std::runtime_error(
"[load_safetensor] Invalid json metadata " + in_stream->label());
}
// Parse the json raw data
std::map<std::string, array> res;
return res;

View File

@ -18,7 +18,7 @@ using JSONList = std::vector<JSONNode*>;
class JSONNode {
public:
enum class Type { OBJECT, LIST, STRING, NUMBER, BOOLEAN, NULL_TYPE };
enum class Type { OBJECT, LIST, STRING, NUMBER, NULL_TYPE };
JSONNode() : _type(Type::NULL_TYPE){};
JSONNode(Type type) : _type(type) {

View File

@ -47,6 +47,9 @@ TEST_CASE("test parseJson") {
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::LIST));
raw = std::string("[");
CHECK_THROWS_AS(io::parseJson(raw.c_str(), raw.size()), std::runtime_error);
raw = std::string("[{}, \"test\"]");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::LIST));