diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt
index bd28537f1..e7e708bd4 100644
--- a/mlx/CMakeLists.txt
+++ b/mlx/CMakeLists.txt
@@ -11,6 +11,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
diff --git a/mlx/mlx.h b/mlx/mlx.h
index 102d2dde9..c67684ab2 100644
--- a/mlx/mlx.h
+++ b/mlx/mlx.h
@@ -8,6 +8,7 @@
 #include "mlx/fft.h"
 #include "mlx/ops.h"
 #include "mlx/random.h"
+#include "mlx/safetensor.h"
 #include "mlx/stream.h"
 #include "mlx/transforms.h"
 #include "mlx/utils.h"
diff --git a/mlx/ops.h b/mlx/ops.h
index fe59d4e49..fed4a4adb 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1057,4 +1057,12 @@ array dequantize(
     int bits = 4,
     StreamOrDevice s = {});
 
+/** Load a map of arrays from a .safetensors file */
+std::map<std::string, array> load_safetensor(
+    std::shared_ptr<io::Reader> in_stream,
+    StreamOrDevice s = {});
+std::map<std::string, array> load_safetensor(
+    const std::string& file,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core
diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp
new file mode 100644
index 000000000..0e4838fa0
--- /dev/null
+++ b/mlx/safetensor.cpp
@@ -0,0 +1,160 @@
+#include "mlx/safetensor.h"
+
+#include <stack>
+
+namespace mlx::core {
+
+namespace io {
+Token Tokenizer::getToken() {
+  if (!this->hasMoreTokens()) {
+    return Token{TOKEN::NULL_TYPE};
+  }
+  char nextChar = this->_data[this->_loc];
+  while ((nextChar == ' ' || nextChar == '\n') && this->hasMoreTokens()) {
+    nextChar = this->_data[++this->_loc];
+  }
+  if (!this->hasMoreTokens()) {
+    return Token{TOKEN::NULL_TYPE};
+  }
+  switch (nextChar) {
+    case '{':
+      this->_loc++;
+      return Token{TOKEN::CURLY_OPEN};
+    case '}':
+      this->_loc++;
+      return Token{TOKEN::CURLY_CLOSE};
+    case ':':
+      this->_loc++;
+      return Token{TOKEN::COLON};
+    case '[':
+      this->_loc++;
+      return Token{TOKEN::ARRAY_OPEN};
+    case ']':
+      this->_loc++;
+      return Token{TOKEN::ARRAY_CLOSE};
+    case ',':
+      this->_loc++;
+      return Token{TOKEN::COMMA};
+    case '"': {
+      size_t start = this->_loc;
+      this->_loc++;
+      while (_data[this->_loc] != '"' && this->hasMoreTokens()) {
+        this->_loc++;
+      }
+      if (!this->hasMoreTokens()) {
+        throw std::runtime_error("no more chars to parse");
+      }
+      // pass the last "
+      this->_loc++;
+      return Token{TOKEN::STRING, start, this->_loc};
+    }
+    default: {
+      size_t start = this->_loc;
+      while ((nextChar != ',' && nextChar != '}' && nextChar != ']' &&
+              nextChar != ' ' && nextChar != '\n') &&
+             this->hasMoreTokens()) {
+        nextChar = this->_data[++this->_loc];
+      }
+      if (!this->hasMoreTokens()) {
+        throw std::runtime_error("no more chars to parse");
+      }
+      return Token{TOKEN::NUMBER, start, this->_loc};
+    }
+  }
+}
+
+// JSONNode parseJson(char* data, size_t len) {
+//   auto tokenizer = Tokenizer(data, len);
+//   std::stack<JSONNode*> ctx;
+//   auto token = tokenizer.getToken();
+//   auto parent = new JSONNode();
+
+//   switch (token.type) {
+//     case TOKEN::CURLY_OPEN:
+//       parent->setObject(new JSONObject());
+//       break;
+//     case TOKEN::ARRAY_OPEN:
+//       parent->setList(new JSONList());
+//       break;
+//     default:
+//       throw new std::runtime_error("invalid json");
+//   }
+//   ctx.push(parent);
+
+//   while (tokenizer.hasMoreTokens()) {
+//     auto token = tokenizer.getToken();
+//     switch (token.type) {
+//       case TOKEN::CURLY_OPEN:
+//         ctx.push(new JSONNode(JSONNode::Type::OBJECT));
+//         break;
+//       case TOKEN::CURLY_CLOSE:
+//         if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
+//           auto obj = ctx.top();
+//           ctx.pop();
+//           if (ctx.top()->is_type(JSONNode::Type::LIST)) {
+//             auto list = ctx.top()->getList();
+//             list->push_back(obj);
+//           } else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
+//             //
+//             auto key = ctx.top();
+//             ctx.pop();
+//             if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
+//               ctx.top()->getObject()->insert({key->getString(), obj});
+//             }
+//           }
+//         } else {
+//           throw new std::runtime_error("invalid json");
+//         }
+//         break;
+//       case TOKEN::COLON:
+//         break;
+//       case TOKEN::ARRAY_OPEN:
+//         break;
+//       case TOKEN::ARRAY_CLOSE:
+//         break;
+//       case TOKEN::COMMA:
+//         break;
+//       case TOKEN::NULL_TYPE:
+//         break;
+//       case TOKEN::STRING:
+//         break;
+//       case TOKEN::NUMBER:
+//         break;
+//     }
+//   }
+// }
+
+} // namespace io
+
+/** Load a map of arrays from a reader in .safetensors format */
+std::map<std::string, array> load_safetensor(
+    std::shared_ptr<io::Reader> in_stream,
+    StreamOrDevice s) {
+  ////////////////////////////////////////////////////////
+  // Open and check file
+  if (!in_stream->good() || !in_stream->is_open()) {
+    throw std::runtime_error(
+        "[load_safetensor] Failed to open " + in_stream->label());
+  }
+
+  uint64_t jsonHeaderLength = 0;
+  in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
+  if (jsonHeaderLength == 0) {
+    throw std::runtime_error(
+        "[load_safetensor] Invalid json header length " + in_stream->label());
+  }
+  // Load the json metadata
+  std::vector<char> json(jsonHeaderLength);
+  in_stream->read(json.data(), jsonHeaderLength);
+  // Parse the json raw data
+  std::map<std::string, array> res;
+  return res;
+}
+
+std::map<std::string, array> load_safetensor(
+    const std::string& file,
+    StreamOrDevice s) {
+  return load_safetensor(std::make_shared<io::FileReader>(file), s);
+}
+
+} // namespace mlx::core
\ No newline at end of file
diff --git a/mlx/safetensor.h b/mlx/safetensor.h
new file mode 100644
index 000000000..097711816
--- /dev/null
+++ b/mlx/safetensor.h
@@ -0,0 +1,79 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <map>
+
+#include "mlx/load.h"
+#include "mlx/ops.h"
+#include "mlx/primitives.h"
+#include "mlx/utils.h"
+
+namespace mlx::core {
+
+namespace io {
+class JSONNode;
+using JSONObject = std::map<std::string, std::shared_ptr<JSONNode>>;
+using JSONList = std::vector<std::shared_ptr<JSONNode>>;
+
+class JSONNode {
+ public:
+  enum class Type { OBJECT, LIST, STRING, NUMBER, BOOLEAN, NULL_TYPE };
+
+  JSONNode() : _type(Type::NULL_TYPE){};
+  JSONNode(Type type) : _type(type) {
+    // set the default value
+    if (type == Type::OBJECT) {
+      this->_values.object = new JSONObject();
+    } else if (type == Type::LIST) {
+      this->_values.list = new JSONList();
+    }
+  };
+
+  inline bool is_type(Type t) {
+    return this->_type == t;
+  }
+
+ private:
+  union Values {
+    JSONObject* object;
+    JSONList* list;
+    std::string* s;
+    float fValue;
+  } _values;
+  Type _type;
+};
+
+enum class TOKEN {
+  CURLY_OPEN,
+  CURLY_CLOSE,
+  COLON,
+  STRING,
+  NUMBER,
+  ARRAY_OPEN,
+  ARRAY_CLOSE,
+  COMMA,
+  NULL_TYPE,
+};
+
+struct Token {
+  TOKEN type;
+  size_t start;
+  size_t end;
+};
+
+class Tokenizer {
+ public:
+  Tokenizer(const char* data, size_t len) : _data(data), _len(len), _loc(0){};
+  Token getToken();
+  inline bool hasMoreTokens() {
+    return this->_loc < this->_len;
+  };
+
+ private:
+  const char* _data;
+  size_t _len;
+  size_t _loc;
+};
+} // namespace io
+} // namespace mlx::core
\ No newline at end of file
diff --git a/temp.safe b/temp.safe
new file mode 100644
index 000000000..ecba8b7b5
Binary files /dev/null and b/temp.safe differ
diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp
index f2489ca72..471313aa0 100644
--- a/tests/load_tests.cpp
+++ b/tests/load_tests.cpp
@@ -14,6 +14,25 @@ std::string get_temp_file(const std::string& name) {
   return std::filesystem::temp_directory_path().append(name);
 }
 
+TEST_CASE("test tokenizer") {
+  auto raw = std::string(" { \"testing\": [1 , \"test\"]} ");
+  auto tokenizer = io::Tokenizer(raw.c_str(), raw.size());
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN);
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON);
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_OPEN);
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NUMBER);
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COMMA);
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE);
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
+  CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
+}
+
+// TEST_CASE("test load_safetensor") {
+//   auto array = load_safetensor("../../temp.safe");
+// }
+
 TEST_CASE("test single array serialization") {
   // Basic test
   {
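For reviewers, a minimal usage sketch of the API this patch introduces; it is not part of the patch itself, assumes the `load_safetensor(const std::string&)` overload declared in mlx/ops.h above, and uses a hypothetical "weights.safetensors" file. Note that with the patch as it stands the returned map is still empty, because the JSON header parsing in mlx/safetensor.cpp is stubbed out.

    #include <iostream>

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // Load every tensor stored in the file into a name -> array map.
      auto weights = load_safetensor("weights.safetensors");
      for (auto& [name, arr] : weights) {
        // Print each tensor's name and rank.
        std::cout << name << ": " << arr.ndim() << " dims" << std::endl;
      }
      return 0;
    }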