mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
covering more cases
This commit is contained in:
parent
87ec7b3cf9
commit
bd422decc4
@ -36,17 +36,14 @@ Token Tokenizer::getToken() {
|
|||||||
this->_loc++;
|
this->_loc++;
|
||||||
return Token{TOKEN::COMMA};
|
return Token{TOKEN::COMMA};
|
||||||
case '"': {
|
case '"': {
|
||||||
size_t start = this->_loc;
|
size_t start = ++this->_loc;
|
||||||
this->_loc++;
|
while (_data[++this->_loc] != '"' && this->hasMoreTokens())
|
||||||
while (_data[this->_loc] != '"' && this->hasMoreTokens()) {
|
;
|
||||||
this->_loc++;
|
|
||||||
}
|
|
||||||
if (!this->hasMoreTokens()) {
|
if (!this->hasMoreTokens()) {
|
||||||
throw new std::runtime_error("no more chars to parse");
|
throw new std::runtime_error("no more chars to parse");
|
||||||
}
|
}
|
||||||
// pass the last "
|
// pass the last "
|
||||||
this->_loc++;
|
return Token{TOKEN::STRING, start, ++this->_loc};
|
||||||
return Token{TOKEN::STRING, start, this->_loc};
|
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
size_t start = this->_loc;
|
size_t start = this->_loc;
|
||||||
@ -63,66 +60,105 @@ Token Tokenizer::getToken() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// JSONNode parseJson(char* data, size_t len) {
|
JSONNode parseJson(const char* data, size_t len) {
|
||||||
// auto tokenizer = Tokenizer(data, len);
|
auto tokenizer = Tokenizer(data, len);
|
||||||
// std::stack<JSONNode*> ctx;
|
std::stack<JSONNode*> ctx;
|
||||||
// auto token = tokenizer.getToken();
|
while (tokenizer.hasMoreTokens()) {
|
||||||
// auto parent = new JSONNode();
|
auto token = tokenizer.getToken();
|
||||||
|
switch (token.type) {
|
||||||
|
case TOKEN::NULL_TYPE:
|
||||||
|
break;
|
||||||
|
case TOKEN::CURLY_OPEN:
|
||||||
|
ctx.push(new JSONNode(JSONNode::Type::OBJECT));
|
||||||
|
break;
|
||||||
|
case TOKEN::CURLY_CLOSE:
|
||||||
|
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
|
auto obj = ctx.top();
|
||||||
|
ctx.pop();
|
||||||
|
// top-level object
|
||||||
|
if (ctx.size() == 0) {
|
||||||
|
return *obj;
|
||||||
|
}
|
||||||
|
|
||||||
// switch (token.type) {
|
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||||
// case TOKEN::CURLY_OPEN:
|
auto key = ctx.top();
|
||||||
// parent->setObject(new JSONObject());
|
ctx.pop();
|
||||||
// break;
|
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
// case TOKEN::ARRAY_OPEN:
|
ctx.top()->getObject()->insert({key->getString(), obj});
|
||||||
// parent->setList(new JSONList());
|
} else {
|
||||||
// break;
|
throw new std::runtime_error("invalid json");
|
||||||
// default:
|
}
|
||||||
// throw new std::runtime_error("invalid json");
|
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
// }
|
auto list = ctx.top()->getList();
|
||||||
// ctx.push(parent);
|
list->push_back(obj);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case TOKEN::ARRAY_OPEN:
|
||||||
|
ctx.push(new JSONNode(JSONNode::Type::LIST));
|
||||||
|
break;
|
||||||
|
case TOKEN::ARRAY_CLOSE:
|
||||||
|
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||||
|
// key is above
|
||||||
|
auto key = ctx.top();
|
||||||
|
ctx.pop();
|
||||||
|
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
|
ctx.top()->getObject()->insert({key->getString(), new JSONNode()});
|
||||||
|
}
|
||||||
|
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
|
auto obj = ctx.top();
|
||||||
|
ctx.pop();
|
||||||
|
// top-level object
|
||||||
|
if (ctx.size() == 0) {
|
||||||
|
return *obj;
|
||||||
|
}
|
||||||
|
|
||||||
// while (tokenizer.hasMoreTokens()) {
|
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
// auto token = tokenizer.getToken();
|
auto list = ctx.top()->getList();
|
||||||
// switch (token.type) {
|
list->push_back(obj);
|
||||||
// case TOKEN::CURLY_OPEN:
|
}
|
||||||
// ctx.push(new JSONNode(JSONNode::Type::OBJECT));
|
}
|
||||||
// break;
|
break;
|
||||||
// case TOKEN::CURLY_CLOSE:
|
case TOKEN::STRING: {
|
||||||
// if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
auto str =
|
||||||
// auto obj = ctx.top();
|
new std::string(data + token.start, token.end - token.start - 1);
|
||||||
// ctx.pop();
|
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
// if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
ctx.push(new JSONNode(str));
|
||||||
// auto list = ctx.top()->getList();
|
} else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||||
// list->push_back(obj);
|
auto key = ctx.top();
|
||||||
// } else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
ctx.pop();
|
||||||
// //
|
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
// auto key = ctx.top();
|
ctx.top()->getObject()->insert(
|
||||||
// ctx.pop();
|
{key->getString(), new JSONNode(str)});
|
||||||
// if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
} else {
|
||||||
// ctx.top()->getObject()->insert({key->getString(), obj});
|
throw new std::runtime_error("invalid json");
|
||||||
// }
|
}
|
||||||
// }
|
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
// } else {
|
ctx.top()->getList()->push_back(new JSONNode(str));
|
||||||
// throw new std::runtime_error("invalid json");
|
}
|
||||||
// }
|
break;
|
||||||
// break;
|
}
|
||||||
// case TOKEN::COLON:
|
case TOKEN::NUMBER:
|
||||||
// break;
|
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||||
// case TOKEN::ARRAY_OPEN:
|
ctx.top()->getList()->push_back(new JSONNode(JSONNode::Type::NUMBER));
|
||||||
// break;
|
} else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||||
// case TOKEN::ARRAY_CLOSE:
|
auto key = ctx.top();
|
||||||
// break;
|
ctx.pop();
|
||||||
// case TOKEN::COMMA:
|
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||||
// break;
|
ctx.top()->getObject()->insert(
|
||||||
// case TOKEN::NULL_TYPE:
|
{key->getString(), new JSONNode(JSONNode::Type::NUMBER)});
|
||||||
// break;
|
} else {
|
||||||
// case TOKEN::STRING:
|
throw new std::runtime_error("invalid json");
|
||||||
// break;
|
}
|
||||||
// case TOKEN::NUMBER:
|
}
|
||||||
// break;
|
break;
|
||||||
// }
|
case TOKEN::COMMA:
|
||||||
// }
|
break;
|
||||||
// }
|
case TOKEN::COLON:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace io
|
} // namespace io
|
||||||
|
|
||||||
@ -141,7 +177,7 @@ std::map<std::string, array> load_safetensor(
|
|||||||
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
|
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
|
||||||
if (jsonHeaderLength <= 0) {
|
if (jsonHeaderLength <= 0) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[load_safetensor] Invalid json header lenght " + in_stream->label());
|
"[load_safetensor] Invalid json header length " + in_stream->label());
|
||||||
}
|
}
|
||||||
// Load the json metadata
|
// Load the json metadata
|
||||||
char json[jsonHeaderLength];
|
char json[jsonHeaderLength];
|
||||||
|
@ -13,8 +13,8 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace io {
|
namespace io {
|
||||||
class JSONNode;
|
class JSONNode;
|
||||||
using JSONObject = std::map<std::string, std::shared_ptr<JSONNode>>;
|
using JSONObject = std::map<std::string, JSONNode*>;
|
||||||
using JSONList = std::vector<std::shared_ptr<JSONNode>>;
|
using JSONList = std::vector<JSONNode*>;
|
||||||
|
|
||||||
class JSONNode {
|
class JSONNode {
|
||||||
public:
|
public:
|
||||||
@ -29,11 +29,39 @@ class JSONNode {
|
|||||||
this->_values.list = new JSONList();
|
this->_values.list = new JSONList();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
JSONNode(std::string* s) : _type(Type::STRING) {
|
||||||
|
this->_values.s = s;
|
||||||
|
};
|
||||||
|
|
||||||
|
JSONObject* getObject() {
|
||||||
|
if (!is_type(Type::OBJECT)) {
|
||||||
|
throw new std::runtime_error("not an object");
|
||||||
|
}
|
||||||
|
return this->_values.object;
|
||||||
|
}
|
||||||
|
|
||||||
|
JSONList* getList() {
|
||||||
|
if (!is_type(Type::LIST)) {
|
||||||
|
throw new std::runtime_error("not a list");
|
||||||
|
}
|
||||||
|
return this->_values.list;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string getString() {
|
||||||
|
if (!is_type(Type::STRING)) {
|
||||||
|
throw new std::runtime_error("not a string");
|
||||||
|
}
|
||||||
|
return *this->_values.s;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool is_type(Type t) {
|
inline bool is_type(Type t) {
|
||||||
return this->_type == t;
|
return this->_type == t;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline Type type() const {
|
||||||
|
return this->_type;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
union Values {
|
union Values {
|
||||||
JSONObject* object;
|
JSONObject* object;
|
||||||
@ -44,6 +72,8 @@ class JSONNode {
|
|||||||
Type _type;
|
Type _type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
JSONNode parseJson(const char* data, size_t len);
|
||||||
|
|
||||||
enum class TOKEN {
|
enum class TOKEN {
|
||||||
CURLY_OPEN,
|
CURLY_OPEN,
|
||||||
CURLY_CLOSE,
|
CURLY_CLOSE,
|
||||||
|
@ -27,11 +27,48 @@ TEST_CASE("test tokenizer") {
|
|||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE);
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE);
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
|
||||||
|
|
||||||
|
raw = std::string(" { \"testing\": \"test\"} ");
|
||||||
|
tokenizer = io::Tokenizer(raw.c_str(), raw.size());
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
|
||||||
|
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TEST_CASE("test load_safetensor") {
|
TEST_CASE("test load_safetensor") {
|
||||||
// auto array = load_safetensor("../../temp.safe");
|
auto raw = std::string("{}");
|
||||||
// }
|
auto res = io::parseJson(raw.c_str(), raw.size());
|
||||||
|
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||||
|
raw = std::string("[]");
|
||||||
|
res = io::parseJson(raw.c_str(), raw.size());
|
||||||
|
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
||||||
|
raw = std::string("[{}, \"test\"]");
|
||||||
|
res = io::parseJson(raw.c_str(), raw.size());
|
||||||
|
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
||||||
|
CHECK_EQ(res.getList()->size(), 2);
|
||||||
|
CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT));
|
||||||
|
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
|
||||||
|
MESSAGE(res.getList()->at(1)->getString());
|
||||||
|
CHECK_EQ(res.getList()->at(1)->getString(), "test");
|
||||||
|
raw = std::string("{\"test\": \"test\", \"test_num\": 1}");
|
||||||
|
res = io::parseJson(raw.c_str(), raw.size());
|
||||||
|
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||||
|
CHECK_EQ(res.getObject()->size(), 2);
|
||||||
|
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::STRING));
|
||||||
|
CHECK_EQ(res.getObject()->at("test")->getString(), "test");
|
||||||
|
CHECK(res.getObject()->at("test_num")->is_type(io::JSONNode::Type::NUMBER));
|
||||||
|
raw = std::string("{\"test\": { \"test\": \"test\"}}");
|
||||||
|
res = io::parseJson(raw.c_str(), raw.size());
|
||||||
|
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||||
|
CHECK_EQ(res.getObject()->size(), 1);
|
||||||
|
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
|
||||||
|
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1);
|
||||||
|
CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type(
|
||||||
|
io::JSONNode::Type::STRING));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("test single array serialization") {
|
TEST_CASE("test single array serialization") {
|
||||||
// Basic test
|
// Basic test
|
||||||
|
Loading…
Reference in New Issue
Block a user