fixed array parsing

This commit is contained in:
dc-dc-dc 2023-12-18 10:28:41 -05:00
parent 91495382fd
commit fef579cec1
2 changed files with 136 additions and 17 deletions

View File

@ -83,32 +83,36 @@ JSONNode parseJson(const char* data, size_t len) {
throw std::runtime_error("invalid json");
}
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
auto list = ctx.top()->getList();
list->push_back(obj);
ctx.top()->getList()->push_back(obj);
}
}
break;
case TOKEN::ARRAY_CLOSE:
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
auto obj = ctx.top();
ctx.pop();
if (ctx.size() == 0) {
return *obj;
}
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
// key is above
auto key = ctx.top();
ctx.pop();
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
ctx.top()->getObject()->insert({key->getString(), new JSONNode()});
ctx.top()->getObject()->insert({key->getString(), obj});
} else {
throw std::runtime_error(
"invalid json, string/array key pair did not have object parent");
}
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
auto obj = ctx.top();
ctx.pop();
// top-level object
if (ctx.size() == 0) {
return *obj;
}
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
auto list = ctx.top()->getList();
list->push_back(obj);
ctx.top()->getList()->push_back(obj);
}
}
} else {
throw std::runtime_error(
"invalid json, could not find array to close");
}
break;
case TOKEN::STRING: {
auto str =
@ -155,7 +159,7 @@ JSONNode parseJson(const char* data, size_t len) {
break;
}
}
throw std::runtime_error("invalid json");
throw std::runtime_error("[unreachable] invalid json");
}
} // namespace io
@ -187,6 +191,22 @@ std::map<std::string, array> load_safetensor(
}
// Parse the json raw data
std::map<std::string, array> res;
for (const auto& key : *metadata.getObject()) {
std::string dtype = key.second->getObject()->at("dtype")->getString();
auto shape = key.second->getObject()->at("shape")->getList();
std::vector<int> shape_vec;
for (const auto& dim : *shape) {
shape_vec.push_back(dim->getNumber());
}
auto data_offsets = key.second->getObject()->at("data_offsets")->getList();
std::vector<int64_t> data_offsets_vec;
for (const auto& offset : *data_offsets) {
data_offsets_vec.push_back(offset->getNumber());
}
if (dtype == "F32") {
res.insert({key.first, zeros(shape_vec, s)});
}
}
return res;
}

View File

@ -56,7 +56,6 @@ TEST_CASE("test parseJson") {
CHECK_EQ(res.getList()->size(), 2);
CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT));
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
MESSAGE(res.getList()->at(1)->getString());
CHECK_EQ(res.getList()->at(1)->getString(), "test");
raw = std::string("{\"test\": \"test\", \"test_num\": 1}");
@ -76,6 +75,106 @@ TEST_CASE("test parseJson") {
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1);
CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type(
io::JSONNode::Type::STRING));
raw = std::string("{\"test\":[1, 2]}");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->size(), 1);
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::LIST));
CHECK_EQ(res.getObject()->at("test")->getList()->size(), 2);
CHECK(res.getObject()->at("test")->getList()->at(0)->is_type(
io::JSONNode::Type::NUMBER));
CHECK_EQ(res.getObject()->at("test")->getList()->at(0)->getNumber(), 1);
CHECK(res.getObject()->at("test")->getList()->at(1)->is_type(
io::JSONNode::Type::NUMBER));
CHECK_EQ(res.getObject()->at("test")->getList()->at(1)->getNumber(), 2);
raw = std::string(
"{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->size(), 1);
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 3);
CHECK(res.getObject()->at("test")->getObject()->at("dtype")->is_type(
io::JSONNode::Type::STRING));
CHECK_EQ(
res.getObject()->at("test")->getObject()->at("dtype")->getString(),
"F32");
CHECK(res.getObject()->at("test")->getObject()->at("shape")->is_type(
io::JSONNode::Type::LIST));
CHECK_EQ(
res.getObject()->at("test")->getObject()->at("shape")->getList()->size(),
1);
CHECK(res.getObject()
->at("test")
->getObject()
->at("shape")
->getList()
->at(0)
->is_type(io::JSONNode::Type::NUMBER));
CHECK_EQ(
res.getObject()
->at("test")
->getObject()
->at("shape")
->getList()
->at(0)
->getNumber(),
4);
CHECK(res.getObject()
->at("test")
->getObject()
->at("data_offsets")
->is_type(io::JSONNode::Type::LIST));
CHECK_EQ(
res.getObject()
->at("test")
->getObject()
->at("data_offsets")
->getList()
->size(),
2);
CHECK(res.getObject()
->at("test")
->getObject()
->at("data_offsets")
->getList()
->at(0)
->is_type(io::JSONNode::Type::NUMBER));
CHECK_EQ(
res.getObject()
->at("test")
->getObject()
->at("data_offsets")
->getList()
->at(0)
->getNumber(),
0);
CHECK(res.getObject()
->at("test")
->getObject()
->at("data_offsets")
->getList()
->at(1)
->is_type(io::JSONNode::Type::NUMBER));
CHECK_EQ(
res.getObject()
->at("test")
->getObject()
->at("data_offsets")
->getList()
->at(1)
->getNumber(),
16);
}
// Loads the fixture file at the relative path "../../temp.safe" and checks
// that the result holds exactly one entry, keyed "test", whose array reports
// dtype float32 and shape {4}.
// NOTE(review): the relative fixture path presumably depends on the working
// directory the test binary is launched from -- confirm against the build setup.
TEST_CASE("test load_safetensor") {
  auto tensors = load_safetensor("../../temp.safe");

  // Exactly one entry is expected, stored under the key "test".
  CHECK_EQ(tensors.size(), 1);
  CHECK_EQ(tensors.count("test"), 1);

  // Inspect the metadata reported for that entry.
  array loaded = tensors.at("test");
  const std::vector<int> expected_shape({4});
  CHECK_EQ(loaded.dtype(), float32);
  CHECK_EQ(loaded.shape(), expected_shape);
}
TEST_CASE("test single array serialization") {