From fef579cec1f18b85acec7ba7bbe929d2aef49683 Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Mon, 18 Dec 2023 10:28:41 -0500 Subject: [PATCH] fixed array parsing --- mlx/safetensor.cpp | 52 +++++++++++++++------- tests/load_tests.cpp | 101 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 136 insertions(+), 17 deletions(-) diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 75679a4fc..d986da6e7 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -83,31 +83,35 @@ JSONNode parseJson(const char* data, size_t len) { throw std::runtime_error("invalid json"); } } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { - auto list = ctx.top()->getList(); - list->push_back(obj); + ctx.top()->getList()->push_back(obj); } } break; case TOKEN::ARRAY_CLOSE: - if (ctx.top()->is_type(JSONNode::Type::STRING)) { - // key is above - auto key = ctx.top(); - ctx.pop(); - if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { - ctx.top()->getObject()->insert({key->getString(), new JSONNode()}); - } - } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { + if (ctx.top()->is_type(JSONNode::Type::LIST)) { auto obj = ctx.top(); ctx.pop(); - // top-level object if (ctx.size() == 0) { return *obj; } - - if (ctx.top()->is_type(JSONNode::Type::LIST)) { - auto list = ctx.top()->getList(); - list->push_back(obj); + if (ctx.top()->is_type(JSONNode::Type::STRING)) { + // key is above + auto key = ctx.top(); + ctx.pop(); + if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { + ctx.top()->getObject()->insert({key->getString(), obj}); + } else { + throw std::runtime_error( + "invalid json, string/array key pair did not have object parent"); + } + } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { + if (ctx.top()->is_type(JSONNode::Type::LIST)) { + ctx.top()->getList()->push_back(obj); + } } + } else { + throw std::runtime_error( + "invalid json, could not find array to close"); } break; case TOKEN::STRING: { @@ -155,7 +159,7 @@ JSONNode parseJson(const char* data, size_t len) { break; } } - throw std::runtime_error("invalid json"); + throw std::runtime_error("[unreachable] invalid json"); } } // namespace io @@ -187,6 +191,22 @@ std::map load_safetensor( } // Parse the json raw data std::map res; + for (const auto& key : *metadata.getObject()) { + std::string dtype = key.second->getObject()->at("dtype")->getString(); + auto shape = key.second->getObject()->at("shape")->getList(); + std::vector shape_vec; + for (const auto& dim : *shape) { + shape_vec.push_back(dim->getNumber()); + } + auto data_offsets = key.second->getObject()->at("data_offsets")->getList(); + std::vector data_offsets_vec; + for (const auto& offset : *data_offsets) { + data_offsets_vec.push_back(offset->getNumber()); + } + if (dtype == "F32") { + res.insert({key.first, zeros(shape_vec, s)}); + } + } return res; } diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 5eae0276a..06f003c2a 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -56,7 +56,6 @@ TEST_CASE("test parseJson") { CHECK_EQ(res.getList()->size(), 2); CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT)); CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING)); - MESSAGE(res.getList()->at(1)->getString()); CHECK_EQ(res.getList()->at(1)->getString(), "test"); raw = std::string("{\"test\": \"test\", \"test_num\": 1}"); @@ -76,6 +75,106 @@ TEST_CASE("test parseJson") { CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1); CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type( io::JSONNode::Type::STRING)); + + raw = std::string("{\"test\":[1, 2]}"); + res = io::parseJson(raw.c_str(), raw.size()); + CHECK(res.is_type(io::JSONNode::Type::OBJECT)); + CHECK_EQ(res.getObject()->size(), 1); + CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::LIST)); + CHECK_EQ(res.getObject()->at("test")->getList()->size(), 2); + CHECK(res.getObject()->at("test")->getList()->at(0)->is_type( + io::JSONNode::Type::NUMBER)); + CHECK_EQ(res.getObject()->at("test")->getList()->at(0)->getNumber(), 1); + CHECK(res.getObject()->at("test")->getList()->at(1)->is_type( + io::JSONNode::Type::NUMBER)); + CHECK_EQ(res.getObject()->at("test")->getList()->at(1)->getNumber(), 2); + raw = std::string( + "{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}"); + res = io::parseJson(raw.c_str(), raw.size()); + CHECK(res.is_type(io::JSONNode::Type::OBJECT)); + CHECK_EQ(res.getObject()->size(), 1); + CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT)); + CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 3); + CHECK(res.getObject()->at("test")->getObject()->at("dtype")->is_type( + io::JSONNode::Type::STRING)); + CHECK_EQ( + res.getObject()->at("test")->getObject()->at("dtype")->getString(), + "F32"); + CHECK(res.getObject()->at("test")->getObject()->at("shape")->is_type( + io::JSONNode::Type::LIST)); + CHECK_EQ( + res.getObject()->at("test")->getObject()->at("shape")->getList()->size(), + 1); + CHECK(res.getObject() + ->at("test") + ->getObject() + ->at("shape") + ->getList() + ->at(0) + ->is_type(io::JSONNode::Type::NUMBER)); + CHECK_EQ( + res.getObject() + ->at("test") + ->getObject() + ->at("shape") + ->getList() + ->at(0) + ->getNumber(), + 4); + CHECK(res.getObject() + ->at("test") + ->getObject() + ->at("data_offsets") + ->is_type(io::JSONNode::Type::LIST)); + CHECK_EQ( + res.getObject() + ->at("test") + ->getObject() + ->at("data_offsets") + ->getList() + ->size(), + 2); + CHECK(res.getObject() + ->at("test") + ->getObject() + ->at("data_offsets") + ->getList() + ->at(0) + ->is_type(io::JSONNode::Type::NUMBER)); + CHECK_EQ( + res.getObject() + ->at("test") + ->getObject() + ->at("data_offsets") + ->getList() + ->at(0) + ->getNumber(), + 0); + CHECK(res.getObject() + ->at("test") + ->getObject() + ->at("data_offsets") + ->getList() + ->at(1) + ->is_type(io::JSONNode::Type::NUMBER)); + CHECK_EQ( + res.getObject() + ->at("test") + ->getObject() + ->at("data_offsets") + ->getList() + ->at(1) + ->getNumber(), + 16); +} + +TEST_CASE("test load_safetensor") { + auto safeDict = load_safetensor("../../temp.safe"); + CHECK_EQ(safeDict.size(), 1); + CHECK_EQ(safeDict.count("test"), 1); + array test = safeDict.at("test"); + CHECK_EQ(test.dtype(), float32); + CHECK_EQ(test.shape(), std::vector({4})); } TEST_CASE("test single array serialization") {