diff --git a/CMakeLists.txt b/CMakeLists.txt index 70293ebba..8e5a50fce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,6 +98,10 @@ elseif (MLX_BUILD_METAL) ${QUARTZ_LIB}) endif() +MESSAGE(STATUS "Downloading json") +find_package(nlohmann_json 3.11.3 REQUIRED) +target_link_libraries(mlx nlohmann_json::nlohmann_json) + find_library(ACCELERATE_LIBRARY Accelerate) if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY) message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") @@ -152,6 +156,8 @@ if (MLX_BUILD_BENCHMARKS) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp) endif() + + # ----------------------------- Installation ----------------------------- include(GNUInstallDirs) diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 30f5d2b5a..1ec3f2964 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -4,200 +4,6 @@ namespace mlx::core { -namespace io { -Token Tokenizer::getToken() { - if (!this->hasMoreTokens()) { - return Token{TOKEN::NULL_TYPE}; - } - char nextChar = this->_data[this->_loc]; - while ((nextChar == ' ' || nextChar == '\n') && this->hasMoreTokens()) { - nextChar = this->_data[++this->_loc]; - } - if (!this->hasMoreTokens()) { - return Token{TOKEN::NULL_TYPE}; - } - // loc is not that important here, but need to increment location - // so might as well do it all in one line - switch (nextChar) { - case '{': - return Token{TOKEN::CURLY_OPEN, ++this->_loc}; - case '}': - return Token{TOKEN::CURLY_CLOSE, ++this->_loc}; - case ':': - return Token{TOKEN::COLON, ++this->_loc}; - case '[': - return Token{TOKEN::ARRAY_OPEN, ++this->_loc}; - case ']': - return Token{TOKEN::ARRAY_CLOSE, ++this->_loc}; - case ',': - return Token{TOKEN::COMMA, ++this->_loc}; - case '"': { - size_t start = ++this->_loc; - while (_data[++this->_loc] != '"' && this->hasMoreTokens()) - ; - if (!this->hasMoreTokens()) { - throw std::runtime_error("no more chars to parse"); - } - return Token{TOKEN::STRING, start, ++this->_loc}; - } - default: { - size_t start = this->_loc; - while ((nextChar != ',' && nextChar != '}' && nextChar != ']' && - nextChar != ' ' && nextChar != '\n') && - this->hasMoreTokens()) { - nextChar = this->_data[++this->_loc]; - } - if (!this->hasMoreTokens()) { - throw std::runtime_error("no more chars to parse"); - } - return Token{TOKEN::NUMBER, start, this->_loc}; - } - } -} - -JSONNode jsonDeserialize(const char* data, size_t len) { - auto tokenizer = Tokenizer(data, len); - std::stack ctx; - while (tokenizer.hasMoreTokens()) { - auto token = tokenizer.getToken(); - switch (token.type) { - case TOKEN::CURLY_OPEN: - ctx.push(new JSONNode(JSONNode::Type::OBJECT)); - break; - case TOKEN::ARRAY_OPEN: - ctx.push(new JSONNode(JSONNode::Type::LIST)); - break; - case TOKEN::CURLY_CLOSE: - if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { - auto obj = ctx.top(); - ctx.pop(); - // top-level object - if (ctx.size() == 0) { - return *obj; - } - - if (ctx.top()->is_type(JSONNode::Type::STRING)) { - auto key = ctx.top(); - ctx.pop(); - if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { - ctx.top()->getObject()->insert({key->getString(), obj}); - } else { - throw std::runtime_error("invalid json"); - } - } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { - ctx.top()->getList()->push_back(obj); - } - } - break; - case TOKEN::ARRAY_CLOSE: - if (ctx.top()->is_type(JSONNode::Type::LIST)) { - auto obj = ctx.top(); - ctx.pop(); - if (ctx.size() == 0) { - return *obj; - } - if (ctx.top()->is_type(JSONNode::Type::STRING)) { - auto key = ctx.top(); - ctx.pop(); - if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { - ctx.top()->getObject()->insert({key->getString(), obj}); - } else { - throw std::runtime_error( - "invalid json, string/array key pair did not have object parent"); - } - } else if (ctx.top()->is_type(JSONNode::Type::LIST)) { - if (ctx.top()->is_type(JSONNode::Type::LIST)) { - ctx.top()->getList()->push_back(obj); - } - } - } else { - throw std::runtime_error( - "invalid json, could not find array to close"); - } - break; - case TOKEN::STRING: { - auto str = - new std::string(data + token.start, token.end - token.start - 1); - if (ctx.top()->is_type(JSONNode::Type::LIST)) { - ctx.top()->getList()->push_back(new JSONNode(str)); - } else if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { - ctx.push(new JSONNode(str)); - } else if (ctx.top()->is_type(JSONNode::Type::STRING)) { - auto key = ctx.top(); - ctx.pop(); - if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { - ctx.top()->getObject()->insert( - {key->getString(), new JSONNode(str)}); - } else { - throw std::runtime_error("invalid json"); - } - } - break; - } - case TOKEN::NUMBER: { - // TODO: is there an easier way of doing this. - auto str = new std::string(data + token.start, token.end - token.start); - auto val = strtoul(str->c_str(), nullptr, 10); - if (ctx.top()->is_type(JSONNode::Type::LIST)) { - ctx.top()->getList()->push_back(new JSONNode(val)); - } else if (ctx.top()->is_type(JSONNode::Type::STRING)) { - auto key = ctx.top(); - ctx.pop(); - if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { - ctx.top()->getObject()->insert( - {key->getString(), new JSONNode(val)}); - } else { - throw std::runtime_error("invalid json"); - } - } - break; - } - default: - break; - } - } - throw std::runtime_error( - "[jsonDeserialize] json was invalid and could not be parsed"); -} - -std::string jsonSerialize(JSONNode* node) { - std::string res; - if (node->is_type(JSONNode::Type::STRING)) { - return "\"" + node->getString() + "\""; - } - if (node->is_type(JSONNode::Type::NUMBER)) { - return std::to_string(node->getNumber()); - } - if (node->is_type(JSONNode::Type::LIST)) { - res += "["; - for (auto& item : *node->getList()) { - res += jsonSerialize(item); - res += ","; - } - if (res.back() == ',') { - res.pop_back(); - } - res += "]"; - return res; - } - if (node->is_type(JSONNode::Type::OBJECT)) { - res += "{"; - for (auto& [key, item] : *node->getObject()) { - res += "\"" + key + "\":"; - res += jsonSerialize(item); - res += ","; - } - if (res.back() == ',') { - res.pop_back(); - } - res += "}"; - return res; - } - - throw std::runtime_error("[jsonSerialize] invalid json node"); -} - -} // namespace io std::string dtype_to_safetensor_str(Dtype t) { if (t == float32) { return ST_F32; @@ -276,44 +82,34 @@ std::unordered_map load_safetensor( "[load_safetensor] Invalid json header length " + in_stream->label()); } // Load the json metadata - char json[jsonHeaderLength]; - in_stream->read(json, jsonHeaderLength); - auto metadata = io::jsonDeserialize(json, jsonHeaderLength); + char rawJson[jsonHeaderLength]; + in_stream->read(rawJson, jsonHeaderLength); + auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength); + // auto metadata = io::jsonDeserialize(json, jsonHeaderLength); // Should always be an object on the top-level - if (!metadata.is_type(io::JSONNode::Type::OBJECT)) { + if (!metadata.is_object()) { throw std::runtime_error( "[load_safetensor] Invalid json metadata " + in_stream->label()); } size_t offset = jsonHeaderLength + 8; // Load the arrays using metadata std::unordered_map res; - for (auto& [key, obj] : *metadata.getObject()) { - if (key == "__metadata__") { + for (const auto& item : metadata.items()) { + if (item.key() == "__metadata__") { // ignore metadata for now continue; } - std::string dtype = obj->getObject()->at("dtype")->getString(); - auto shape = obj->getObject()->at("shape")->getList(); - std::vector shape_vec; - for (const auto& dim : *shape) { - shape_vec.push_back(dim->getNumber()); - } - auto data_offsets = obj->getObject()->at("data_offsets")->getList(); - std::vector data_offsets_vec; - for (const auto& offset : *data_offsets) { - data_offsets_vec.push_back(offset->getNumber()); - } + std::string dtype = item.value().at("dtype"); + std::vector shape = item.value().at("shape"); + std::vector data_offsets = item.value().at("data_offsets"); Dtype type = dtype_from_safetensor_str(dtype); auto loaded_array = array( - shape_vec, + shape, type, std::make_unique( - to_stream(s), - in_stream, - offset + data_offsets->at(0)->getNumber(), - false), + to_stream(s), in_stream, offset + data_offsets.at(0), false), std::vector{}); - res.insert({key, loaded_array}); + res.insert({item.key(), loaded_array}); } return res; } @@ -330,8 +126,8 @@ void save_safetensor( std::unordered_map a) { //////////////////////////////////////////////////////// // Check array map + json parent; - io::JSONNode metadata(io::JSONNode::Type::OBJECT); size_t offset = 0; for (auto& [key, arr] : a) { arr.eval(false); @@ -345,29 +141,12 @@ void save_safetensor( "[save_safetensor] cannot serialize a non-contiguous array key: " + key); } - auto obj = new io::JSONNode(io::JSONNode::Type::OBJECT); + json child; // TODO: dont make a new string - obj->getObject()->insert( - {"dtype", - new io::JSONNode( - new std::string(dtype_to_safetensor_str(arr.dtype())))}); - obj->getObject()->insert( - {"shape", new io::JSONNode(io::JSONNode::Type::LIST)}); - for (auto& dim : arr.shape()) { - obj->getObject()->at("shape")->getList()->push_back( - new io::JSONNode(dim)); - } - obj->getObject()->insert( - {"data_offsets", new io::JSONNode(io::JSONNode::Type::LIST)}); - obj->getObject() - ->at("data_offsets") - ->getList() - ->push_back(new io::JSONNode(offset)); - obj->getObject() - ->at("data_offsets") - ->getList() - ->push_back(new io::JSONNode(offset + arr.nbytes())); - metadata.getObject()->insert({key, obj}); + child["dtype"] = dtype_to_safetensor_str(arr.dtype()); + child["shape"] = arr.shape(); + child["data_offsets"] = std::vector{offset, offset + arr.nbytes()}; + parent[key] = child; offset += arr.nbytes(); } @@ -378,7 +157,7 @@ void save_safetensor( "[save_safetensor] Failed to open " + out_stream->label()); } - auto header = io::jsonSerialize(&metadata); + auto header = parent.dump(); uint64_t header_len = header.length(); out_stream->write(reinterpret_cast(&header_len), 8); out_stream->write(header.c_str(), header_len); @@ -393,7 +172,7 @@ void save_safetensor( // Open and check file std::string file = file_; - // Add .npy to file name if it is not there + // Add .safetensors to file name if it is not there if (file.length() < 12 || file.substr(file.length() - 12, 12) != ".safetensors") file += ".safetensors"; diff --git a/mlx/safetensor.h b/mlx/safetensor.h index 2e2b9c9c1..022c88913 100644 --- a/mlx/safetensor.h +++ b/mlx/safetensor.h @@ -2,11 +2,15 @@ #pragma once +#include + #include "mlx/load.h" #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/utils.h" +using json = nlohmann::json; + namespace mlx::core { #define ST_F16 "F16" @@ -22,116 +26,4 @@ namespace mlx::core { #define ST_U16 "U16" #define ST_U32 "U32" #define ST_U64 "U64" - -namespace io { -// NOTE: This json parser is a bare minimum implementation for safetensors, -// it does not support all of json features, and does not have alot of edge case -// catches. This is okay as safe tensor json is very simple and we can assume it -// is always valid and well formed, but this should not be used for general json -// parsing -class JSONNode; -using JSONObject = std::unordered_map; -using JSONList = std::vector; - -class JSONNode { - public: - enum class Type { OBJECT, LIST, STRING, NUMBER, NULL_TYPE }; - - JSONNode() : _type(Type::NULL_TYPE){}; - JSONNode(Type type) : _type(type) { - // set the default value - if (type == Type::OBJECT) { - this->_values.object = new JSONObject(); - } else if (type == Type::LIST) { - this->_values.list = new JSONList(); - } - }; - JSONNode(std::string* s) : _type(Type::STRING) { - this->_values.s = s; - }; - JSONNode(float f) : _type(Type::NUMBER) { - this->_values.f = f; - }; - - JSONObject* getObject() { - if (!is_type(Type::OBJECT)) { - throw new std::runtime_error("not an object"); - } - return this->_values.object; - } - - JSONList* getList() { - if (!is_type(Type::LIST)) { - throw new std::runtime_error("not a list"); - } - return this->_values.list; - } - - std::string getString() { - if (!is_type(Type::STRING)) { - throw new std::runtime_error("not a string"); - } - return *this->_values.s; - } - - uint32_t getNumber() { - if (!is_type(Type::NUMBER)) { - throw new std::runtime_error("not a number"); - } - return this->_values.f; - } - - inline bool is_type(Type t) { - return this->_type == t; - } - - inline Type type() const { - return this->_type; - } - - private: - union Values { - JSONObject* object; - JSONList* list; - std::string* s; - uint64_t f; - } _values; - Type _type; -}; - -JSONNode jsonDeserialize(const char* data, size_t len); -std::string jsonSerialize(JSONNode* node); - -enum class TOKEN { - CURLY_OPEN, - CURLY_CLOSE, - COLON, - STRING, - NUMBER, - ARRAY_OPEN, - ARRAY_CLOSE, - COMMA, - NULL_TYPE, -}; - -struct Token { - TOKEN type; - size_t start; - size_t end; -}; - -class Tokenizer { - public: - Tokenizer(const char* data, size_t len) : _data(data), _loc(0), _len(len){}; - Token getToken(); - inline bool hasMoreTokens() { - return this->_loc < this->_len; - }; - - private: - const char* _data; - size_t _len; - size_t _loc; -}; -} // namespace io } // namespace mlx::core \ No newline at end of file diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 60e06d670..1245f50f5 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -14,152 +14,6 @@ std::string get_temp_file(const std::string& name) { return std::filesystem::temp_directory_path().append(name); } -TEST_CASE("test tokenizer") { - auto raw = std::string(" { \"testing\": [1 , \"test\"]} "); - auto tokenizer = io::Tokenizer(raw.c_str(), raw.size()); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_OPEN); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NUMBER); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COMMA); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE); - - raw = std::string(" { \"testing\": \"test\"} "); - tokenizer = io::Tokenizer(raw.c_str(), raw.size()); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE); - CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE); -} - -TEST_CASE("test jsonSerialize") { - auto test = new io::JSONNode(io::JSONNode::Type::OBJECT); - auto src = io::jsonSerialize(test); - CHECK_EQ(src, "{}"); - test = new io::JSONNode(io::JSONNode::Type::LIST); - src = io::jsonSerialize(test); - CHECK_EQ(src, "[]"); - test = new io::JSONNode(io::JSONNode::Type::OBJECT); - test->getObject()->insert( - {"test", new io::JSONNode(new std::string("testing"))}); - src = io::jsonSerialize(test); - CHECK_EQ(src, "{\"test\":\"testing\"}"); - test = new io::JSONNode(io::JSONNode::Type::OBJECT); - auto arr = new io::JSONNode(io::JSONNode::Type::LIST); - arr->getList()->push_back(new io::JSONNode(1)); - arr->getList()->push_back(new io::JSONNode(2)); - test->getObject()->insert({"test", arr}); - src = io::jsonSerialize(test); - CHECK_EQ(src, "{\"test\":[1,2]}"); -} - -TEST_CASE("test jsonDeserialize") { - auto raw = std::string("{}"); - auto res = io::jsonDeserialize(raw.c_str(), raw.size()); - CHECK(res.is_type(io::JSONNode::Type::OBJECT)); - - raw = std::string("[]"); - res = io::jsonDeserialize(raw.c_str(), raw.size()); - CHECK(res.is_type(io::JSONNode::Type::LIST)); - - raw = std::string("["); - CHECK_THROWS_AS( - io::jsonDeserialize(raw.c_str(), raw.size()), std::runtime_error); - - raw = std::string("[{}, \"test\"]"); - res = io::jsonDeserialize(raw.c_str(), raw.size()); - CHECK(res.is_type(io::JSONNode::Type::LIST)); - CHECK_EQ(res.getList()->size(), 2); - CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT)); - CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING)); - CHECK_EQ(res.getList()->at(1)->getString(), "test"); - - raw = std::string( - "{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}"); - res = io::jsonDeserialize(raw.c_str(), raw.size()); - CHECK(res.is_type(io::JSONNode::Type::OBJECT)); - CHECK_EQ(res.getObject()->size(), 1); - CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT)); - CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 3); - CHECK(res.getObject()->at("test")->getObject()->at("dtype")->is_type( - io::JSONNode::Type::STRING)); - CHECK_EQ( - res.getObject()->at("test")->getObject()->at("dtype")->getString(), - "F32"); - CHECK(res.getObject()->at("test")->getObject()->at("shape")->is_type( - io::JSONNode::Type::LIST)); - CHECK_EQ( - res.getObject()->at("test")->getObject()->at("shape")->getList()->size(), - 1); - CHECK(res.getObject() - ->at("test") - ->getObject() - ->at("shape") - ->getList() - ->at(0) - ->is_type(io::JSONNode::Type::NUMBER)); - CHECK_EQ( - res.getObject() - ->at("test") - ->getObject() - ->at("shape") - ->getList() - ->at(0) - ->getNumber(), - 4); - CHECK(res.getObject() - ->at("test") - ->getObject() - ->at("data_offsets") - ->is_type(io::JSONNode::Type::LIST)); - CHECK_EQ( - res.getObject() - ->at("test") - ->getObject() - ->at("data_offsets") - ->getList() - ->size(), - 2); - CHECK(res.getObject() - ->at("test") - ->getObject() - ->at("data_offsets") - ->getList() - ->at(0) - ->is_type(io::JSONNode::Type::NUMBER)); - CHECK_EQ( - res.getObject() - ->at("test") - ->getObject() - ->at("data_offsets") - ->getList() - ->at(0) - ->getNumber(), - 0); - CHECK(res.getObject() - ->at("test") - ->getObject() - ->at("data_offsets") - ->getList() - ->at(1) - ->is_type(io::JSONNode::Type::NUMBER)); - CHECK_EQ( - res.getObject() - ->at("test") - ->getObject() - ->at("data_offsets") - ->getList() - ->at(1) - ->getNumber(), - 16); -} - TEST_CASE("test save_safetensor") { std::string file_path = get_temp_file("test_arr.safetensors"); auto map = std::unordered_map();