mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
more cleanup
This commit is contained in:
parent
d0285db98c
commit
91495382fd
@ -56,7 +56,6 @@ Token Tokenizer::getToken() {
|
|||||||
JSONNode parseJson(const char* data, size_t len) {
|
JSONNode parseJson(const char* data, size_t len) {
|
||||||
auto tokenizer = Tokenizer(data, len);
|
auto tokenizer = Tokenizer(data, len);
|
||||||
std::stack<JSONNode*> ctx;
|
std::stack<JSONNode*> ctx;
|
||||||
std::string key;
|
|
||||||
while (tokenizer.hasMoreTokens()) {
|
while (tokenizer.hasMoreTokens()) {
|
||||||
auto token = tokenizer.getToken();
|
auto token = tokenizer.getToken();
|
||||||
switch (token.type) {
|
switch (token.type) {
|
||||||
@ -81,7 +80,7 @@ JSONNode parseJson(const char* data, size_t len) {
|
|||||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
ctx.top()->getObject()->insert({key->getString(), obj});
|
ctx.top()->getObject()->insert({key->getString(), obj});
|
||||||
} else {
|
} else {
|
||||||
throw new std::runtime_error("invalid json");
|
throw std::runtime_error("invalid json");
|
||||||
}
|
}
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
auto list = ctx.top()->getList();
|
auto list = ctx.top()->getList();
|
||||||
@ -89,7 +88,6 @@ JSONNode parseJson(const char* data, size_t len) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case TOKEN::ARRAY_CLOSE:
|
case TOKEN::ARRAY_CLOSE:
|
||||||
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||||
// key is above
|
// key is above
|
||||||
@ -115,7 +113,9 @@ JSONNode parseJson(const char* data, size_t len) {
|
|||||||
case TOKEN::STRING: {
|
case TOKEN::STRING: {
|
||||||
auto str =
|
auto str =
|
||||||
new std::string(data + token.start, token.end - token.start - 1);
|
new std::string(data + token.start, token.end - token.start - 1);
|
||||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
|
ctx.top()->getList()->push_back(new JSONNode(str));
|
||||||
|
} else if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
ctx.push(new JSONNode(str));
|
ctx.push(new JSONNode(str));
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
} else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||||
auto key = ctx.top();
|
auto key = ctx.top();
|
||||||
@ -124,10 +124,8 @@ JSONNode parseJson(const char* data, size_t len) {
|
|||||||
ctx.top()->getObject()->insert(
|
ctx.top()->getObject()->insert(
|
||||||
{key->getString(), new JSONNode(str)});
|
{key->getString(), new JSONNode(str)});
|
||||||
} else {
|
} else {
|
||||||
throw new std::runtime_error("invalid json");
|
throw std::runtime_error("invalid json");
|
||||||
}
|
}
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
|
||||||
ctx.top()->getList()->push_back(new JSONNode(str));
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -144,7 +142,7 @@ JSONNode parseJson(const char* data, size_t len) {
|
|||||||
ctx.top()->getObject()->insert(
|
ctx.top()->getObject()->insert(
|
||||||
{key->getString(), new JSONNode(val)});
|
{key->getString(), new JSONNode(val)});
|
||||||
} else {
|
} else {
|
||||||
throw new std::runtime_error("invalid json");
|
throw std::runtime_error("invalid json");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@ -157,6 +155,7 @@ JSONNode parseJson(const char* data, size_t len) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
throw std::runtime_error("invalid json");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace io
|
} // namespace io
|
||||||
@ -181,6 +180,11 @@ std::map<std::string, array> load_safetensor(
|
|||||||
// Load the json metadata
|
// Load the json metadata
|
||||||
char json[jsonHeaderLength];
|
char json[jsonHeaderLength];
|
||||||
in_stream->read(json, jsonHeaderLength);
|
in_stream->read(json, jsonHeaderLength);
|
||||||
|
auto metadata = io::parseJson(json, jsonHeaderLength);
|
||||||
|
if (!metadata.is_type(io::JSONNode::Type::OBJECT)) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
||||||
|
}
|
||||||
// Parse the json raw data
|
// Parse the json raw data
|
||||||
std::map<std::string, array> res;
|
std::map<std::string, array> res;
|
||||||
return res;
|
return res;
|
||||||
|
@ -18,7 +18,7 @@ using JSONList = std::vector<JSONNode*>;
|
|||||||
|
|
||||||
class JSONNode {
|
class JSONNode {
|
||||||
public:
|
public:
|
||||||
enum class Type { OBJECT, LIST, STRING, NUMBER, BOOLEAN, NULL_TYPE };
|
enum class Type { OBJECT, LIST, STRING, NUMBER, NULL_TYPE };
|
||||||
|
|
||||||
JSONNode() : _type(Type::NULL_TYPE){};
|
JSONNode() : _type(Type::NULL_TYPE){};
|
||||||
JSONNode(Type type) : _type(type) {
|
JSONNode(Type type) : _type(type) {
|
||||||
|
@ -47,6 +47,9 @@ TEST_CASE("test parseJson") {
|
|||||||
res = io::parseJson(raw.c_str(), raw.size());
|
res = io::parseJson(raw.c_str(), raw.size());
|
||||||
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
||||||
|
|
||||||
|
raw = std::string("[");
|
||||||
|
CHECK_THROWS_AS(io::parseJson(raw.c_str(), raw.size()), std::runtime_error);
|
||||||
|
|
||||||
raw = std::string("[{}, \"test\"]");
|
raw = std::string("[{}, \"test\"]");
|
||||||
res = io::parseJson(raw.c_str(), raw.size());
|
res = io::parseJson(raw.c_str(), raw.size());
|
||||||
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
||||||
|
Loading…
Reference in New Issue
Block a user