diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 07296c592..318a3d082 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -36,7 +36,6 @@ Token Tokenizer::getToken() { if (!this->hasMoreTokens()) { throw new std::runtime_error("no more chars to parse"); } - // pass the last " return Token{TOKEN::STRING, start, ++this->_loc}; } default: { @@ -132,20 +131,24 @@ JSONNode parseJson(const char* data, size_t len) { } break; } - case TOKEN::NUMBER: + case TOKEN::NUMBER: { + // TODO: is there an easier way of doing this. + auto str = new std::string(data + token.start, token.end - token.start); + float val = strtof(str->c_str(), nullptr); if (ctx.top()->is_type(JSONNode::Type::LIST)) { - ctx.top()->getList()->push_back(new JSONNode(JSONNode::Type::NUMBER)); + ctx.top()->getList()->push_back(new JSONNode(val)); } else if (ctx.top()->is_type(JSONNode::Type::STRING)) { auto key = ctx.top(); ctx.pop(); if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { ctx.top()->getObject()->insert( - {key->getString(), new JSONNode(JSONNode::Type::NUMBER)}); + {key->getString(), new JSONNode(val)}); } else { throw new std::runtime_error("invalid json"); } } break; + } case TOKEN::COMMA: break; case TOKEN::COLON: diff --git a/mlx/safetensor.h b/mlx/safetensor.h index 272d29979..2fb5957c0 100644 --- a/mlx/safetensor.h +++ b/mlx/safetensor.h @@ -32,6 +32,9 @@ class JSONNode { JSONNode(std::string* s) : _type(Type::STRING) { this->_values.s = s; }; + JSONNode(float f) : _type(Type::NUMBER) { + this->_values.f = f; + }; JSONObject* getObject() { if (!is_type(Type::OBJECT)) { @@ -54,6 +57,13 @@ class JSONNode { return *this->_values.s; } + float getNumber() { + if (!is_type(Type::NUMBER)) { + throw new std::runtime_error("not a number"); + } + return this->_values.f; + } + inline bool is_type(Type t) { return this->_type == t; } @@ -67,7 +77,7 @@ class JSONNode { JSONObject* object; JSONList* list; std::string* s; - float fValue; + float f; } _values; Type _type; }; diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 70cddcc6e..73fa905dd 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -38,13 +38,15 @@ TEST_CASE("test tokenizer") { CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE); } -TEST_CASE("test load_safetensor") { +TEST_CASE("test parseJson") { auto raw = std::string("{}"); auto res = io::parseJson(raw.c_str(), raw.size()); CHECK(res.is_type(io::JSONNode::Type::OBJECT)); + raw = std::string("[]"); res = io::parseJson(raw.c_str(), raw.size()); CHECK(res.is_type(io::JSONNode::Type::LIST)); + raw = std::string("[{}, \"test\"]"); res = io::parseJson(raw.c_str(), raw.size()); CHECK(res.is_type(io::JSONNode::Type::LIST)); @@ -53,6 +55,7 @@ TEST_CASE("test load_safetensor") { CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING)); MESSAGE(res.getList()->at(1)->getString()); CHECK_EQ(res.getList()->at(1)->getString(), "test"); + raw = std::string("{\"test\": \"test\", \"test_num\": 1}"); res = io::parseJson(raw.c_str(), raw.size()); CHECK(res.is_type(io::JSONNode::Type::OBJECT)); @@ -60,6 +63,8 @@ TEST_CASE("test load_safetensor") { CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::STRING)); CHECK_EQ(res.getObject()->at("test")->getString(), "test"); CHECK(res.getObject()->at("test_num")->is_type(io::JSONNode::Type::NUMBER)); + CHECK_EQ(res.getObject()->at("test_num")->getNumber(), 1); + raw = std::string("{\"test\": { \"test\": \"test\"}}"); res = io::parseJson(raw.c_str(), raw.size()); CHECK(res.is_type(io::JSONNode::Type::OBJECT));