mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:38:07 +08:00
fixed array parsing
This commit is contained in:
parent
91495382fd
commit
fef579cec1
@ -83,32 +83,36 @@ JSONNode parseJson(const char* data, size_t len) {
|
||||
throw std::runtime_error("invalid json");
|
||||
}
|
||||
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||
auto list = ctx.top()->getList();
|
||||
list->push_back(obj);
|
||||
ctx.top()->getList()->push_back(obj);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case TOKEN::ARRAY_CLOSE:
|
||||
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||
auto obj = ctx.top();
|
||||
ctx.pop();
|
||||
if (ctx.size() == 0) {
|
||||
return *obj;
|
||||
}
|
||||
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||
// key is above
|
||||
auto key = ctx.top();
|
||||
ctx.pop();
|
||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||
ctx.top()->getObject()->insert({key->getString(), new JSONNode()});
|
||||
ctx.top()->getObject()->insert({key->getString(), obj});
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"invalid json, string/array key pair did not have object parent");
|
||||
}
|
||||
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||
auto obj = ctx.top();
|
||||
ctx.pop();
|
||||
// top-level object
|
||||
if (ctx.size() == 0) {
|
||||
return *obj;
|
||||
}
|
||||
|
||||
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||
auto list = ctx.top()->getList();
|
||||
list->push_back(obj);
|
||||
ctx.top()->getList()->push_back(obj);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"invalid json, could not find array to close");
|
||||
}
|
||||
break;
|
||||
case TOKEN::STRING: {
|
||||
auto str =
|
||||
@ -155,7 +159,7 @@ JSONNode parseJson(const char* data, size_t len) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("invalid json");
|
||||
throw std::runtime_error("[unreachable] invalid json");
|
||||
}
|
||||
|
||||
} // namespace io
|
||||
@ -187,6 +191,22 @@ std::map<std::string, array> load_safetensor(
|
||||
}
|
||||
// Parse the json raw data
|
||||
std::map<std::string, array> res;
|
||||
for (const auto& key : *metadata.getObject()) {
|
||||
std::string dtype = key.second->getObject()->at("dtype")->getString();
|
||||
auto shape = key.second->getObject()->at("shape")->getList();
|
||||
std::vector<int> shape_vec;
|
||||
for (const auto& dim : *shape) {
|
||||
shape_vec.push_back(dim->getNumber());
|
||||
}
|
||||
auto data_offsets = key.second->getObject()->at("data_offsets")->getList();
|
||||
std::vector<int64_t> data_offsets_vec;
|
||||
for (const auto& offset : *data_offsets) {
|
||||
data_offsets_vec.push_back(offset->getNumber());
|
||||
}
|
||||
if (dtype == "F32") {
|
||||
res.insert({key.first, zeros(shape_vec, s)});
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,6 @@ TEST_CASE("test parseJson") {
|
||||
CHECK_EQ(res.getList()->size(), 2);
|
||||
CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT));
|
||||
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
|
||||
MESSAGE(res.getList()->at(1)->getString());
|
||||
CHECK_EQ(res.getList()->at(1)->getString(), "test");
|
||||
|
||||
raw = std::string("{\"test\": \"test\", \"test_num\": 1}");
|
||||
@ -76,6 +75,106 @@ TEST_CASE("test parseJson") {
|
||||
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1);
|
||||
CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type(
|
||||
io::JSONNode::Type::STRING));
|
||||
|
||||
raw = std::string("{\"test\":[1, 2]}");
|
||||
res = io::parseJson(raw.c_str(), raw.size());
|
||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||
CHECK_EQ(res.getObject()->size(), 1);
|
||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::LIST));
|
||||
CHECK_EQ(res.getObject()->at("test")->getList()->size(), 2);
|
||||
CHECK(res.getObject()->at("test")->getList()->at(0)->is_type(
|
||||
io::JSONNode::Type::NUMBER));
|
||||
CHECK_EQ(res.getObject()->at("test")->getList()->at(0)->getNumber(), 1);
|
||||
CHECK(res.getObject()->at("test")->getList()->at(1)->is_type(
|
||||
io::JSONNode::Type::NUMBER));
|
||||
CHECK_EQ(res.getObject()->at("test")->getList()->at(1)->getNumber(), 2);
|
||||
raw = std::string(
|
||||
"{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
|
||||
res = io::parseJson(raw.c_str(), raw.size());
|
||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||
CHECK_EQ(res.getObject()->size(), 1);
|
||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
|
||||
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 3);
|
||||
CHECK(res.getObject()->at("test")->getObject()->at("dtype")->is_type(
|
||||
io::JSONNode::Type::STRING));
|
||||
CHECK_EQ(
|
||||
res.getObject()->at("test")->getObject()->at("dtype")->getString(),
|
||||
"F32");
|
||||
CHECK(res.getObject()->at("test")->getObject()->at("shape")->is_type(
|
||||
io::JSONNode::Type::LIST));
|
||||
CHECK_EQ(
|
||||
res.getObject()->at("test")->getObject()->at("shape")->getList()->size(),
|
||||
1);
|
||||
CHECK(res.getObject()
|
||||
->at("test")
|
||||
->getObject()
|
||||
->at("shape")
|
||||
->getList()
|
||||
->at(0)
|
||||
->is_type(io::JSONNode::Type::NUMBER));
|
||||
CHECK_EQ(
|
||||
res.getObject()
|
||||
->at("test")
|
||||
->getObject()
|
||||
->at("shape")
|
||||
->getList()
|
||||
->at(0)
|
||||
->getNumber(),
|
||||
4);
|
||||
CHECK(res.getObject()
|
||||
->at("test")
|
||||
->getObject()
|
||||
->at("data_offsets")
|
||||
->is_type(io::JSONNode::Type::LIST));
|
||||
CHECK_EQ(
|
||||
res.getObject()
|
||||
->at("test")
|
||||
->getObject()
|
||||
->at("data_offsets")
|
||||
->getList()
|
||||
->size(),
|
||||
2);
|
||||
CHECK(res.getObject()
|
||||
->at("test")
|
||||
->getObject()
|
||||
->at("data_offsets")
|
||||
->getList()
|
||||
->at(0)
|
||||
->is_type(io::JSONNode::Type::NUMBER));
|
||||
CHECK_EQ(
|
||||
res.getObject()
|
||||
->at("test")
|
||||
->getObject()
|
||||
->at("data_offsets")
|
||||
->getList()
|
||||
->at(0)
|
||||
->getNumber(),
|
||||
0);
|
||||
CHECK(res.getObject()
|
||||
->at("test")
|
||||
->getObject()
|
||||
->at("data_offsets")
|
||||
->getList()
|
||||
->at(1)
|
||||
->is_type(io::JSONNode::Type::NUMBER));
|
||||
CHECK_EQ(
|
||||
res.getObject()
|
||||
->at("test")
|
||||
->getObject()
|
||||
->at("data_offsets")
|
||||
->getList()
|
||||
->at(1)
|
||||
->getNumber(),
|
||||
16);
|
||||
}
|
||||
|
||||
TEST_CASE("test load_safetensor") {
|
||||
auto safeDict = load_safetensor("../../temp.safe");
|
||||
CHECK_EQ(safeDict.size(), 1);
|
||||
CHECK_EQ(safeDict.count("test"), 1);
|
||||
array test = safeDict.at("test");
|
||||
CHECK_EQ(test.dtype(), float32);
|
||||
CHECK_EQ(test.shape(), std::vector<int>({4}));
|
||||
}
|
||||
|
||||
TEST_CASE("test single array serialization") {
|
||||
|
Loading…
Reference in New Issue
Block a user