mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
fixed array parsing
This commit is contained in:
parent
91495382fd
commit
fef579cec1
@ -83,32 +83,36 @@ JSONNode parseJson(const char* data, size_t len) {
|
|||||||
throw std::runtime_error("invalid json");
|
throw std::runtime_error("invalid json");
|
||||||
}
|
}
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
auto list = ctx.top()->getList();
|
ctx.top()->getList()->push_back(obj);
|
||||||
list->push_back(obj);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case TOKEN::ARRAY_CLOSE:
|
case TOKEN::ARRAY_CLOSE:
|
||||||
|
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
|
auto obj = ctx.top();
|
||||||
|
ctx.pop();
|
||||||
|
if (ctx.size() == 0) {
|
||||||
|
return *obj;
|
||||||
|
}
|
||||||
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||||
// key is above
|
// key is above
|
||||||
auto key = ctx.top();
|
auto key = ctx.top();
|
||||||
ctx.pop();
|
ctx.pop();
|
||||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
ctx.top()->getObject()->insert({key->getString(), new JSONNode()});
|
ctx.top()->getObject()->insert({key->getString(), obj});
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"invalid json, string/array key pair did not have object parent");
|
||||||
}
|
}
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
auto obj = ctx.top();
|
|
||||||
ctx.pop();
|
|
||||||
// top-level object
|
|
||||||
if (ctx.size() == 0) {
|
|
||||||
return *obj;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
auto list = ctx.top()->getList();
|
ctx.top()->getList()->push_back(obj);
|
||||||
list->push_back(obj);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"invalid json, could not find array to close");
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case TOKEN::STRING: {
|
case TOKEN::STRING: {
|
||||||
auto str =
|
auto str =
|
||||||
@ -155,7 +159,7 @@ JSONNode parseJson(const char* data, size_t len) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
throw std::runtime_error("invalid json");
|
throw std::runtime_error("[unreachable] invalid json");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace io
|
} // namespace io
|
||||||
@ -187,6 +191,22 @@ std::map<std::string, array> load_safetensor(
|
|||||||
}
|
}
|
||||||
// Parse the json raw data
|
// Parse the json raw data
|
||||||
std::map<std::string, array> res;
|
std::map<std::string, array> res;
|
||||||
|
for (const auto& key : *metadata.getObject()) {
|
||||||
|
std::string dtype = key.second->getObject()->at("dtype")->getString();
|
||||||
|
auto shape = key.second->getObject()->at("shape")->getList();
|
||||||
|
std::vector<int> shape_vec;
|
||||||
|
for (const auto& dim : *shape) {
|
||||||
|
shape_vec.push_back(dim->getNumber());
|
||||||
|
}
|
||||||
|
auto data_offsets = key.second->getObject()->at("data_offsets")->getList();
|
||||||
|
std::vector<int64_t> data_offsets_vec;
|
||||||
|
for (const auto& offset : *data_offsets) {
|
||||||
|
data_offsets_vec.push_back(offset->getNumber());
|
||||||
|
}
|
||||||
|
if (dtype == "F32") {
|
||||||
|
res.insert({key.first, zeros(shape_vec, s)});
|
||||||
|
}
|
||||||
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,7 +56,6 @@ TEST_CASE("test parseJson") {
|
|||||||
CHECK_EQ(res.getList()->size(), 2);
|
CHECK_EQ(res.getList()->size(), 2);
|
||||||
CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT));
|
CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT));
|
||||||
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
|
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
|
||||||
MESSAGE(res.getList()->at(1)->getString());
|
|
||||||
CHECK_EQ(res.getList()->at(1)->getString(), "test");
|
CHECK_EQ(res.getList()->at(1)->getString(), "test");
|
||||||
|
|
||||||
raw = std::string("{\"test\": \"test\", \"test_num\": 1}");
|
raw = std::string("{\"test\": \"test\", \"test_num\": 1}");
|
||||||
@ -76,6 +75,106 @@ TEST_CASE("test parseJson") {
|
|||||||
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1);
|
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1);
|
||||||
CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type(
|
CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type(
|
||||||
io::JSONNode::Type::STRING));
|
io::JSONNode::Type::STRING));
|
||||||
|
|
||||||
|
raw = std::string("{\"test\":[1, 2]}");
|
||||||
|
res = io::parseJson(raw.c_str(), raw.size());
|
||||||
|
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||||
|
CHECK_EQ(res.getObject()->size(), 1);
|
||||||
|
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::LIST));
|
||||||
|
CHECK_EQ(res.getObject()->at("test")->getList()->size(), 2);
|
||||||
|
CHECK(res.getObject()->at("test")->getList()->at(0)->is_type(
|
||||||
|
io::JSONNode::Type::NUMBER));
|
||||||
|
CHECK_EQ(res.getObject()->at("test")->getList()->at(0)->getNumber(), 1);
|
||||||
|
CHECK(res.getObject()->at("test")->getList()->at(1)->is_type(
|
||||||
|
io::JSONNode::Type::NUMBER));
|
||||||
|
CHECK_EQ(res.getObject()->at("test")->getList()->at(1)->getNumber(), 2);
|
||||||
|
raw = std::string(
|
||||||
|
"{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
|
||||||
|
res = io::parseJson(raw.c_str(), raw.size());
|
||||||
|
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||||
|
CHECK_EQ(res.getObject()->size(), 1);
|
||||||
|
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
|
||||||
|
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 3);
|
||||||
|
CHECK(res.getObject()->at("test")->getObject()->at("dtype")->is_type(
|
||||||
|
io::JSONNode::Type::STRING));
|
||||||
|
CHECK_EQ(
|
||||||
|
res.getObject()->at("test")->getObject()->at("dtype")->getString(),
|
||||||
|
"F32");
|
||||||
|
CHECK(res.getObject()->at("test")->getObject()->at("shape")->is_type(
|
||||||
|
io::JSONNode::Type::LIST));
|
||||||
|
CHECK_EQ(
|
||||||
|
res.getObject()->at("test")->getObject()->at("shape")->getList()->size(),
|
||||||
|
1);
|
||||||
|
CHECK(res.getObject()
|
||||||
|
->at("test")
|
||||||
|
->getObject()
|
||||||
|
->at("shape")
|
||||||
|
->getList()
|
||||||
|
->at(0)
|
||||||
|
->is_type(io::JSONNode::Type::NUMBER));
|
||||||
|
CHECK_EQ(
|
||||||
|
res.getObject()
|
||||||
|
->at("test")
|
||||||
|
->getObject()
|
||||||
|
->at("shape")
|
||||||
|
->getList()
|
||||||
|
->at(0)
|
||||||
|
->getNumber(),
|
||||||
|
4);
|
||||||
|
CHECK(res.getObject()
|
||||||
|
->at("test")
|
||||||
|
->getObject()
|
||||||
|
->at("data_offsets")
|
||||||
|
->is_type(io::JSONNode::Type::LIST));
|
||||||
|
CHECK_EQ(
|
||||||
|
res.getObject()
|
||||||
|
->at("test")
|
||||||
|
->getObject()
|
||||||
|
->at("data_offsets")
|
||||||
|
->getList()
|
||||||
|
->size(),
|
||||||
|
2);
|
||||||
|
CHECK(res.getObject()
|
||||||
|
->at("test")
|
||||||
|
->getObject()
|
||||||
|
->at("data_offsets")
|
||||||
|
->getList()
|
||||||
|
->at(0)
|
||||||
|
->is_type(io::JSONNode::Type::NUMBER));
|
||||||
|
CHECK_EQ(
|
||||||
|
res.getObject()
|
||||||
|
->at("test")
|
||||||
|
->getObject()
|
||||||
|
->at("data_offsets")
|
||||||
|
->getList()
|
||||||
|
->at(0)
|
||||||
|
->getNumber(),
|
||||||
|
0);
|
||||||
|
CHECK(res.getObject()
|
||||||
|
->at("test")
|
||||||
|
->getObject()
|
||||||
|
->at("data_offsets")
|
||||||
|
->getList()
|
||||||
|
->at(1)
|
||||||
|
->is_type(io::JSONNode::Type::NUMBER));
|
||||||
|
CHECK_EQ(
|
||||||
|
res.getObject()
|
||||||
|
->at("test")
|
||||||
|
->getObject()
|
||||||
|
->at("data_offsets")
|
||||||
|
->getList()
|
||||||
|
->at(1)
|
||||||
|
->getNumber(),
|
||||||
|
16);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test load_safetensor") {
|
||||||
|
auto safeDict = load_safetensor("../../temp.safe");
|
||||||
|
CHECK_EQ(safeDict.size(), 1);
|
||||||
|
CHECK_EQ(safeDict.count("test"), 1);
|
||||||
|
array test = safeDict.at("test");
|
||||||
|
CHECK_EQ(test.dtype(), float32);
|
||||||
|
CHECK_EQ(test.shape(), std::vector<int>({4}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test single array serialization") {
|
TEST_CASE("test single array serialization") {
|
||||||
|
Loading…
Reference in New Issue
Block a user