covering more cases

This commit is contained in:
dc-dc-dc 2023-12-17 21:13:10 -05:00
parent 87ec7b3cf9
commit bd422decc4
3 changed files with 174 additions and 71 deletions

View File

@ -36,17 +36,14 @@ Token Tokenizer::getToken() {
this->_loc++; this->_loc++;
return Token{TOKEN::COMMA}; return Token{TOKEN::COMMA};
case '"': { case '"': {
size_t start = this->_loc; size_t start = ++this->_loc;
this->_loc++; while (_data[++this->_loc] != '"' && this->hasMoreTokens())
while (_data[this->_loc] != '"' && this->hasMoreTokens()) { ;
this->_loc++;
}
if (!this->hasMoreTokens()) { if (!this->hasMoreTokens()) {
throw new std::runtime_error("no more chars to parse"); throw new std::runtime_error("no more chars to parse");
} }
// pass the last " // pass the last "
this->_loc++; return Token{TOKEN::STRING, start, ++this->_loc};
return Token{TOKEN::STRING, start, this->_loc};
} }
default: { default: {
size_t start = this->_loc; size_t start = this->_loc;
@ -63,66 +60,105 @@ Token Tokenizer::getToken() {
} }
} }
// JSONNode parseJson(char* data, size_t len) { JSONNode parseJson(const char* data, size_t len) {
// auto tokenizer = Tokenizer(data, len); auto tokenizer = Tokenizer(data, len);
// std::stack<JSONNode*> ctx; std::stack<JSONNode*> ctx;
// auto token = tokenizer.getToken(); while (tokenizer.hasMoreTokens()) {
// auto parent = new JSONNode(); auto token = tokenizer.getToken();
switch (token.type) {
case TOKEN::NULL_TYPE:
break;
case TOKEN::CURLY_OPEN:
ctx.push(new JSONNode(JSONNode::Type::OBJECT));
break;
case TOKEN::CURLY_CLOSE:
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
auto obj = ctx.top();
ctx.pop();
// top-level object
if (ctx.size() == 0) {
return *obj;
}
// switch (token.type) { if (ctx.top()->is_type(JSONNode::Type::STRING)) {
// case TOKEN::CURLY_OPEN: auto key = ctx.top();
// parent->setObject(new JSONObject()); ctx.pop();
// break; if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
// case TOKEN::ARRAY_OPEN: ctx.top()->getObject()->insert({key->getString(), obj});
// parent->setList(new JSONList()); } else {
// break; throw new std::runtime_error("invalid json");
// default: }
// throw new std::runtime_error("invalid json"); } else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
// } auto list = ctx.top()->getList();
// ctx.push(parent); list->push_back(obj);
}
}
break;
case TOKEN::ARRAY_OPEN:
ctx.push(new JSONNode(JSONNode::Type::LIST));
break;
case TOKEN::ARRAY_CLOSE:
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
// key is above
auto key = ctx.top();
ctx.pop();
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
ctx.top()->getObject()->insert({key->getString(), new JSONNode()});
}
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
auto obj = ctx.top();
ctx.pop();
// top-level object
if (ctx.size() == 0) {
return *obj;
}
// while (tokenizer.hasMoreTokens()) { if (ctx.top()->is_type(JSONNode::Type::LIST)) {
// auto token = tokenizer.getToken(); auto list = ctx.top()->getList();
// switch (token.type) { list->push_back(obj);
// case TOKEN::CURLY_OPEN: }
// ctx.push(new JSONNode(JSONNode::Type::OBJECT)); }
// break; break;
// case TOKEN::CURLY_CLOSE: case TOKEN::STRING: {
// if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { auto str =
// auto obj = ctx.top(); new std::string(data + token.start, token.end - token.start - 1);
// ctx.pop(); if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
// if (ctx.top()->is_type(JSONNode::Type::LIST)) { ctx.push(new JSONNode(str));
// auto list = ctx.top()->getList(); } else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
// list->push_back(obj); auto key = ctx.top();
// } else if (ctx.top()->is_type(JSONNode::Type::STRING)) { ctx.pop();
// // if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
// auto key = ctx.top(); ctx.top()->getObject()->insert(
// ctx.pop(); {key->getString(), new JSONNode(str)});
// if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { } else {
// ctx.top()->getObject()->insert({key->getString(), obj}); throw new std::runtime_error("invalid json");
// } }
// } } else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
// } else { ctx.top()->getList()->push_back(new JSONNode(str));
// throw new std::runtime_error("invalid json"); }
// } break;
// break; }
// case TOKEN::COLON: case TOKEN::NUMBER:
// break; if (ctx.top()->is_type(JSONNode::Type::LIST)) {
// case TOKEN::ARRAY_OPEN: ctx.top()->getList()->push_back(new JSONNode(JSONNode::Type::NUMBER));
// break; } else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
// case TOKEN::ARRAY_CLOSE: auto key = ctx.top();
// break; ctx.pop();
// case TOKEN::COMMA: if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
// break; ctx.top()->getObject()->insert(
// case TOKEN::NULL_TYPE: {key->getString(), new JSONNode(JSONNode::Type::NUMBER)});
// break; } else {
// case TOKEN::STRING: throw new std::runtime_error("invalid json");
// break; }
// case TOKEN::NUMBER: }
// break; break;
// } case TOKEN::COMMA:
// } break;
// } case TOKEN::COLON:
break;
}
}
}
} // namespace io } // namespace io
@ -141,7 +177,7 @@ std::map<std::string, array> load_safetensor(
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8); in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
if (jsonHeaderLength <= 0) { if (jsonHeaderLength <= 0) {
throw std::runtime_error( throw std::runtime_error(
"[load_safetensor] Invalid json header lenght " + in_stream->label()); "[load_safetensor] Invalid json header length " + in_stream->label());
} }
// Load the json metadata // Load the json metadata
char json[jsonHeaderLength]; char json[jsonHeaderLength];

View File

@ -13,8 +13,8 @@ namespace mlx::core {
namespace io { namespace io {
class JSONNode; class JSONNode;
using JSONObject = std::map<std::string, std::shared_ptr<JSONNode>>; using JSONObject = std::map<std::string, JSONNode*>;
using JSONList = std::vector<std::shared_ptr<JSONNode>>; using JSONList = std::vector<JSONNode*>;
class JSONNode { class JSONNode {
public: public:
@ -29,11 +29,39 @@ class JSONNode {
this->_values.list = new JSONList(); this->_values.list = new JSONList();
} }
}; };
JSONNode(std::string* s) : _type(Type::STRING) {
this->_values.s = s;
};
JSONObject* getObject() {
if (!is_type(Type::OBJECT)) {
throw new std::runtime_error("not an object");
}
return this->_values.object;
}
JSONList* getList() {
if (!is_type(Type::LIST)) {
throw new std::runtime_error("not a list");
}
return this->_values.list;
}
std::string getString() {
if (!is_type(Type::STRING)) {
throw new std::runtime_error("not a string");
}
return *this->_values.s;
}
inline bool is_type(Type t) { inline bool is_type(Type t) {
return this->_type == t; return this->_type == t;
} }
inline Type type() const {
return this->_type;
}
private: private:
union Values { union Values {
JSONObject* object; JSONObject* object;
@ -44,6 +72,8 @@ class JSONNode {
Type _type; Type _type;
}; };
JSONNode parseJson(const char* data, size_t len);
enum class TOKEN { enum class TOKEN {
CURLY_OPEN, CURLY_OPEN,
CURLY_CLOSE, CURLY_CLOSE,

View File

@ -27,11 +27,48 @@ TEST_CASE("test tokenizer") {
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE); CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE); CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE); CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
raw = std::string(" { \"testing\": \"test\"} ");
tokenizer = io::Tokenizer(raw.c_str(), raw.size());
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
} }
// TEST_CASE("test load_safetensor") { TEST_CASE("test load_safetensor") {
// auto array = load_safetensor("../../temp.safe"); auto raw = std::string("{}");
// } auto res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
raw = std::string("[]");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::LIST));
raw = std::string("[{}, \"test\"]");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::LIST));
CHECK_EQ(res.getList()->size(), 2);
CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT));
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
MESSAGE(res.getList()->at(1)->getString());
CHECK_EQ(res.getList()->at(1)->getString(), "test");
raw = std::string("{\"test\": \"test\", \"test_num\": 1}");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->size(), 2);
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::STRING));
CHECK_EQ(res.getObject()->at("test")->getString(), "test");
CHECK(res.getObject()->at("test_num")->is_type(io::JSONNode::Type::NUMBER));
raw = std::string("{\"test\": { \"test\": \"test\"}}");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->size(), 1);
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1);
CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type(
io::JSONNode::Type::STRING));
}
TEST_CASE("test single array serialization") { TEST_CASE("test single array serialization") {
// Basic test // Basic test