mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
removed custom json parser for nlohmann
This commit is contained in:
parent
99b9c1dac5
commit
472ce433f8
@ -98,6 +98,10 @@ elseif (MLX_BUILD_METAL)
|
|||||||
${QUARTZ_LIB})
|
${QUARTZ_LIB})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
MESSAGE(STATUS "Downloading json")
|
||||||
|
find_package(nlohmann_json 3.11.3 REQUIRED)
|
||||||
|
target_link_libraries(mlx nlohmann_json::nlohmann_json)
|
||||||
|
|
||||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||||
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||||
@ -152,6 +156,8 @@ if (MLX_BUILD_BENCHMARKS)
|
|||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------- Installation -----------------------------
|
# ----------------------------- Installation -----------------------------
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
|
@ -4,200 +4,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace io {
|
|
||||||
Token Tokenizer::getToken() {
|
|
||||||
if (!this->hasMoreTokens()) {
|
|
||||||
return Token{TOKEN::NULL_TYPE};
|
|
||||||
}
|
|
||||||
char nextChar = this->_data[this->_loc];
|
|
||||||
while ((nextChar == ' ' || nextChar == '\n') && this->hasMoreTokens()) {
|
|
||||||
nextChar = this->_data[++this->_loc];
|
|
||||||
}
|
|
||||||
if (!this->hasMoreTokens()) {
|
|
||||||
return Token{TOKEN::NULL_TYPE};
|
|
||||||
}
|
|
||||||
// loc is not that important here, but need to increment location
|
|
||||||
// so might as well do it all in one line
|
|
||||||
switch (nextChar) {
|
|
||||||
case '{':
|
|
||||||
return Token{TOKEN::CURLY_OPEN, ++this->_loc};
|
|
||||||
case '}':
|
|
||||||
return Token{TOKEN::CURLY_CLOSE, ++this->_loc};
|
|
||||||
case ':':
|
|
||||||
return Token{TOKEN::COLON, ++this->_loc};
|
|
||||||
case '[':
|
|
||||||
return Token{TOKEN::ARRAY_OPEN, ++this->_loc};
|
|
||||||
case ']':
|
|
||||||
return Token{TOKEN::ARRAY_CLOSE, ++this->_loc};
|
|
||||||
case ',':
|
|
||||||
return Token{TOKEN::COMMA, ++this->_loc};
|
|
||||||
case '"': {
|
|
||||||
size_t start = ++this->_loc;
|
|
||||||
while (_data[++this->_loc] != '"' && this->hasMoreTokens())
|
|
||||||
;
|
|
||||||
if (!this->hasMoreTokens()) {
|
|
||||||
throw std::runtime_error("no more chars to parse");
|
|
||||||
}
|
|
||||||
return Token{TOKEN::STRING, start, ++this->_loc};
|
|
||||||
}
|
|
||||||
default: {
|
|
||||||
size_t start = this->_loc;
|
|
||||||
while ((nextChar != ',' && nextChar != '}' && nextChar != ']' &&
|
|
||||||
nextChar != ' ' && nextChar != '\n') &&
|
|
||||||
this->hasMoreTokens()) {
|
|
||||||
nextChar = this->_data[++this->_loc];
|
|
||||||
}
|
|
||||||
if (!this->hasMoreTokens()) {
|
|
||||||
throw std::runtime_error("no more chars to parse");
|
|
||||||
}
|
|
||||||
return Token{TOKEN::NUMBER, start, this->_loc};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
JSONNode jsonDeserialize(const char* data, size_t len) {
|
|
||||||
auto tokenizer = Tokenizer(data, len);
|
|
||||||
std::stack<JSONNode*> ctx;
|
|
||||||
while (tokenizer.hasMoreTokens()) {
|
|
||||||
auto token = tokenizer.getToken();
|
|
||||||
switch (token.type) {
|
|
||||||
case TOKEN::CURLY_OPEN:
|
|
||||||
ctx.push(new JSONNode(JSONNode::Type::OBJECT));
|
|
||||||
break;
|
|
||||||
case TOKEN::ARRAY_OPEN:
|
|
||||||
ctx.push(new JSONNode(JSONNode::Type::LIST));
|
|
||||||
break;
|
|
||||||
case TOKEN::CURLY_CLOSE:
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
|
||||||
auto obj = ctx.top();
|
|
||||||
ctx.pop();
|
|
||||||
// top-level object
|
|
||||||
if (ctx.size() == 0) {
|
|
||||||
return *obj;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
|
||||||
auto key = ctx.top();
|
|
||||||
ctx.pop();
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
|
||||||
ctx.top()->getObject()->insert({key->getString(), obj});
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("invalid json");
|
|
||||||
}
|
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
|
||||||
ctx.top()->getList()->push_back(obj);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case TOKEN::ARRAY_CLOSE:
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
|
||||||
auto obj = ctx.top();
|
|
||||||
ctx.pop();
|
|
||||||
if (ctx.size() == 0) {
|
|
||||||
return *obj;
|
|
||||||
}
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
|
||||||
auto key = ctx.top();
|
|
||||||
ctx.pop();
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
|
||||||
ctx.top()->getObject()->insert({key->getString(), obj});
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"invalid json, string/array key pair did not have object parent");
|
|
||||||
}
|
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
|
||||||
ctx.top()->getList()->push_back(obj);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"invalid json, could not find array to close");
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case TOKEN::STRING: {
|
|
||||||
auto str =
|
|
||||||
new std::string(data + token.start, token.end - token.start - 1);
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
|
||||||
ctx.top()->getList()->push_back(new JSONNode(str));
|
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
|
||||||
ctx.push(new JSONNode(str));
|
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
|
||||||
auto key = ctx.top();
|
|
||||||
ctx.pop();
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
|
||||||
ctx.top()->getObject()->insert(
|
|
||||||
{key->getString(), new JSONNode(str)});
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("invalid json");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case TOKEN::NUMBER: {
|
|
||||||
// TODO: is there an easier way of doing this.
|
|
||||||
auto str = new std::string(data + token.start, token.end - token.start);
|
|
||||||
auto val = strtoul(str->c_str(), nullptr, 10);
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
|
||||||
ctx.top()->getList()->push_back(new JSONNode(val));
|
|
||||||
} else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
|
||||||
auto key = ctx.top();
|
|
||||||
ctx.pop();
|
|
||||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
|
||||||
ctx.top()->getObject()->insert(
|
|
||||||
{key->getString(), new JSONNode(val)});
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("invalid json");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[jsonDeserialize] json was invalid and could not be parsed");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string jsonSerialize(JSONNode* node) {
|
|
||||||
std::string res;
|
|
||||||
if (node->is_type(JSONNode::Type::STRING)) {
|
|
||||||
return "\"" + node->getString() + "\"";
|
|
||||||
}
|
|
||||||
if (node->is_type(JSONNode::Type::NUMBER)) {
|
|
||||||
return std::to_string(node->getNumber());
|
|
||||||
}
|
|
||||||
if (node->is_type(JSONNode::Type::LIST)) {
|
|
||||||
res += "[";
|
|
||||||
for (auto& item : *node->getList()) {
|
|
||||||
res += jsonSerialize(item);
|
|
||||||
res += ",";
|
|
||||||
}
|
|
||||||
if (res.back() == ',') {
|
|
||||||
res.pop_back();
|
|
||||||
}
|
|
||||||
res += "]";
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
if (node->is_type(JSONNode::Type::OBJECT)) {
|
|
||||||
res += "{";
|
|
||||||
for (auto& [key, item] : *node->getObject()) {
|
|
||||||
res += "\"" + key + "\":";
|
|
||||||
res += jsonSerialize(item);
|
|
||||||
res += ",";
|
|
||||||
}
|
|
||||||
if (res.back() == ',') {
|
|
||||||
res.pop_back();
|
|
||||||
}
|
|
||||||
res += "}";
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
throw std::runtime_error("[jsonSerialize] invalid json node");
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace io
|
|
||||||
std::string dtype_to_safetensor_str(Dtype t) {
|
std::string dtype_to_safetensor_str(Dtype t) {
|
||||||
if (t == float32) {
|
if (t == float32) {
|
||||||
return ST_F32;
|
return ST_F32;
|
||||||
@ -276,44 +82,34 @@ std::unordered_map<std::string, array> load_safetensor(
|
|||||||
"[load_safetensor] Invalid json header length " + in_stream->label());
|
"[load_safetensor] Invalid json header length " + in_stream->label());
|
||||||
}
|
}
|
||||||
// Load the json metadata
|
// Load the json metadata
|
||||||
char json[jsonHeaderLength];
|
char rawJson[jsonHeaderLength];
|
||||||
in_stream->read(json, jsonHeaderLength);
|
in_stream->read(rawJson, jsonHeaderLength);
|
||||||
auto metadata = io::jsonDeserialize(json, jsonHeaderLength);
|
auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength);
|
||||||
|
// auto metadata = io::jsonDeserialize(json, jsonHeaderLength);
|
||||||
// Should always be an object on the top-level
|
// Should always be an object on the top-level
|
||||||
if (!metadata.is_type(io::JSONNode::Type::OBJECT)) {
|
if (!metadata.is_object()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
||||||
}
|
}
|
||||||
size_t offset = jsonHeaderLength + 8;
|
size_t offset = jsonHeaderLength + 8;
|
||||||
// Load the arrays using metadata
|
// Load the arrays using metadata
|
||||||
std::unordered_map<std::string, array> res;
|
std::unordered_map<std::string, array> res;
|
||||||
for (auto& [key, obj] : *metadata.getObject()) {
|
for (const auto& item : metadata.items()) {
|
||||||
if (key == "__metadata__") {
|
if (item.key() == "__metadata__") {
|
||||||
// ignore metadata for now
|
// ignore metadata for now
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::string dtype = obj->getObject()->at("dtype")->getString();
|
std::string dtype = item.value().at("dtype");
|
||||||
auto shape = obj->getObject()->at("shape")->getList();
|
std::vector<int> shape = item.value().at("shape");
|
||||||
std::vector<int> shape_vec;
|
std::vector<size_t> data_offsets = item.value().at("data_offsets");
|
||||||
for (const auto& dim : *shape) {
|
|
||||||
shape_vec.push_back(dim->getNumber());
|
|
||||||
}
|
|
||||||
auto data_offsets = obj->getObject()->at("data_offsets")->getList();
|
|
||||||
std::vector<int64_t> data_offsets_vec;
|
|
||||||
for (const auto& offset : *data_offsets) {
|
|
||||||
data_offsets_vec.push_back(offset->getNumber());
|
|
||||||
}
|
|
||||||
Dtype type = dtype_from_safetensor_str(dtype);
|
Dtype type = dtype_from_safetensor_str(dtype);
|
||||||
auto loaded_array = array(
|
auto loaded_array = array(
|
||||||
shape_vec,
|
shape,
|
||||||
type,
|
type,
|
||||||
std::make_unique<Load>(
|
std::make_unique<Load>(
|
||||||
to_stream(s),
|
to_stream(s), in_stream, offset + data_offsets.at(0), false),
|
||||||
in_stream,
|
|
||||||
offset + data_offsets->at(0)->getNumber(),
|
|
||||||
false),
|
|
||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
res.insert({key, loaded_array});
|
res.insert({item.key(), loaded_array});
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -330,8 +126,8 @@ void save_safetensor(
|
|||||||
std::unordered_map<std::string, array> a) {
|
std::unordered_map<std::string, array> a) {
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Check array map
|
// Check array map
|
||||||
|
json parent;
|
||||||
|
|
||||||
io::JSONNode metadata(io::JSONNode::Type::OBJECT);
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& [key, arr] : a) {
|
for (auto& [key, arr] : a) {
|
||||||
arr.eval(false);
|
arr.eval(false);
|
||||||
@ -345,29 +141,12 @@ void save_safetensor(
|
|||||||
"[save_safetensor] cannot serialize a non-contiguous array key: " +
|
"[save_safetensor] cannot serialize a non-contiguous array key: " +
|
||||||
key);
|
key);
|
||||||
}
|
}
|
||||||
auto obj = new io::JSONNode(io::JSONNode::Type::OBJECT);
|
json child;
|
||||||
// TODO: dont make a new string
|
// TODO: dont make a new string
|
||||||
obj->getObject()->insert(
|
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
||||||
{"dtype",
|
child["shape"] = arr.shape();
|
||||||
new io::JSONNode(
|
child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()};
|
||||||
new std::string(dtype_to_safetensor_str(arr.dtype())))});
|
parent[key] = child;
|
||||||
obj->getObject()->insert(
|
|
||||||
{"shape", new io::JSONNode(io::JSONNode::Type::LIST)});
|
|
||||||
for (auto& dim : arr.shape()) {
|
|
||||||
obj->getObject()->at("shape")->getList()->push_back(
|
|
||||||
new io::JSONNode(dim));
|
|
||||||
}
|
|
||||||
obj->getObject()->insert(
|
|
||||||
{"data_offsets", new io::JSONNode(io::JSONNode::Type::LIST)});
|
|
||||||
obj->getObject()
|
|
||||||
->at("data_offsets")
|
|
||||||
->getList()
|
|
||||||
->push_back(new io::JSONNode(offset));
|
|
||||||
obj->getObject()
|
|
||||||
->at("data_offsets")
|
|
||||||
->getList()
|
|
||||||
->push_back(new io::JSONNode(offset + arr.nbytes()));
|
|
||||||
metadata.getObject()->insert({key, obj});
|
|
||||||
offset += arr.nbytes();
|
offset += arr.nbytes();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -378,7 +157,7 @@ void save_safetensor(
|
|||||||
"[save_safetensor] Failed to open " + out_stream->label());
|
"[save_safetensor] Failed to open " + out_stream->label());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto header = io::jsonSerialize(&metadata);
|
auto header = parent.dump();
|
||||||
uint64_t header_len = header.length();
|
uint64_t header_len = header.length();
|
||||||
out_stream->write(reinterpret_cast<char*>(&header_len), 8);
|
out_stream->write(reinterpret_cast<char*>(&header_len), 8);
|
||||||
out_stream->write(header.c_str(), header_len);
|
out_stream->write(header.c_str(), header_len);
|
||||||
@ -393,7 +172,7 @@ void save_safetensor(
|
|||||||
// Open and check file
|
// Open and check file
|
||||||
std::string file = file_;
|
std::string file = file_;
|
||||||
|
|
||||||
// Add .npy to file name if it is not there
|
// Add .safetensors to file name if it is not there
|
||||||
if (file.length() < 12 ||
|
if (file.length() < 12 ||
|
||||||
file.substr(file.length() - 12, 12) != ".safetensors")
|
file.substr(file.length() - 12, 12) != ".safetensors")
|
||||||
file += ".safetensors";
|
file += ".safetensors";
|
||||||
|
116
mlx/safetensor.h
116
mlx/safetensor.h
@ -2,11 +2,15 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
#include "mlx/load.h"
|
#include "mlx/load.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
#define ST_F16 "F16"
|
#define ST_F16 "F16"
|
||||||
@ -22,116 +26,4 @@ namespace mlx::core {
|
|||||||
#define ST_U16 "U16"
|
#define ST_U16 "U16"
|
||||||
#define ST_U32 "U32"
|
#define ST_U32 "U32"
|
||||||
#define ST_U64 "U64"
|
#define ST_U64 "U64"
|
||||||
|
|
||||||
namespace io {
|
|
||||||
// NOTE: This json parser is a bare minimum implementation for safetensors,
|
|
||||||
// it does not support all of json features, and does not have alot of edge case
|
|
||||||
// catches. This is okay as safe tensor json is very simple and we can assume it
|
|
||||||
// is always valid and well formed, but this should not be used for general json
|
|
||||||
// parsing
|
|
||||||
class JSONNode;
|
|
||||||
using JSONObject = std::unordered_map<std::string, JSONNode*>;
|
|
||||||
using JSONList = std::vector<JSONNode*>;
|
|
||||||
|
|
||||||
class JSONNode {
|
|
||||||
public:
|
|
||||||
enum class Type { OBJECT, LIST, STRING, NUMBER, NULL_TYPE };
|
|
||||||
|
|
||||||
JSONNode() : _type(Type::NULL_TYPE){};
|
|
||||||
JSONNode(Type type) : _type(type) {
|
|
||||||
// set the default value
|
|
||||||
if (type == Type::OBJECT) {
|
|
||||||
this->_values.object = new JSONObject();
|
|
||||||
} else if (type == Type::LIST) {
|
|
||||||
this->_values.list = new JSONList();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
JSONNode(std::string* s) : _type(Type::STRING) {
|
|
||||||
this->_values.s = s;
|
|
||||||
};
|
|
||||||
JSONNode(float f) : _type(Type::NUMBER) {
|
|
||||||
this->_values.f = f;
|
|
||||||
};
|
|
||||||
|
|
||||||
JSONObject* getObject() {
|
|
||||||
if (!is_type(Type::OBJECT)) {
|
|
||||||
throw new std::runtime_error("not an object");
|
|
||||||
}
|
|
||||||
return this->_values.object;
|
|
||||||
}
|
|
||||||
|
|
||||||
JSONList* getList() {
|
|
||||||
if (!is_type(Type::LIST)) {
|
|
||||||
throw new std::runtime_error("not a list");
|
|
||||||
}
|
|
||||||
return this->_values.list;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string getString() {
|
|
||||||
if (!is_type(Type::STRING)) {
|
|
||||||
throw new std::runtime_error("not a string");
|
|
||||||
}
|
|
||||||
return *this->_values.s;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t getNumber() {
|
|
||||||
if (!is_type(Type::NUMBER)) {
|
|
||||||
throw new std::runtime_error("not a number");
|
|
||||||
}
|
|
||||||
return this->_values.f;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool is_type(Type t) {
|
|
||||||
return this->_type == t;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Type type() const {
|
|
||||||
return this->_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
union Values {
|
|
||||||
JSONObject* object;
|
|
||||||
JSONList* list;
|
|
||||||
std::string* s;
|
|
||||||
uint64_t f;
|
|
||||||
} _values;
|
|
||||||
Type _type;
|
|
||||||
};
|
|
||||||
|
|
||||||
JSONNode jsonDeserialize(const char* data, size_t len);
|
|
||||||
std::string jsonSerialize(JSONNode* node);
|
|
||||||
|
|
||||||
enum class TOKEN {
|
|
||||||
CURLY_OPEN,
|
|
||||||
CURLY_CLOSE,
|
|
||||||
COLON,
|
|
||||||
STRING,
|
|
||||||
NUMBER,
|
|
||||||
ARRAY_OPEN,
|
|
||||||
ARRAY_CLOSE,
|
|
||||||
COMMA,
|
|
||||||
NULL_TYPE,
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Token {
|
|
||||||
TOKEN type;
|
|
||||||
size_t start;
|
|
||||||
size_t end;
|
|
||||||
};
|
|
||||||
|
|
||||||
class Tokenizer {
|
|
||||||
public:
|
|
||||||
Tokenizer(const char* data, size_t len) : _data(data), _loc(0), _len(len){};
|
|
||||||
Token getToken();
|
|
||||||
inline bool hasMoreTokens() {
|
|
||||||
return this->_loc < this->_len;
|
|
||||||
};
|
|
||||||
|
|
||||||
private:
|
|
||||||
const char* _data;
|
|
||||||
size_t _len;
|
|
||||||
size_t _loc;
|
|
||||||
};
|
|
||||||
} // namespace io
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
@ -14,152 +14,6 @@ std::string get_temp_file(const std::string& name) {
|
|||||||
return std::filesystem::temp_directory_path().append(name);
|
return std::filesystem::temp_directory_path().append(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test tokenizer") {
|
|
||||||
auto raw = std::string(" { \"testing\": [1 , \"test\"]} ");
|
|
||||||
auto tokenizer = io::Tokenizer(raw.c_str(), raw.size());
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_OPEN);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NUMBER);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COMMA);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
|
|
||||||
|
|
||||||
raw = std::string(" { \"testing\": \"test\"} ");
|
|
||||||
tokenizer = io::Tokenizer(raw.c_str(), raw.size());
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
|
|
||||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test jsonSerialize") {
|
|
||||||
auto test = new io::JSONNode(io::JSONNode::Type::OBJECT);
|
|
||||||
auto src = io::jsonSerialize(test);
|
|
||||||
CHECK_EQ(src, "{}");
|
|
||||||
test = new io::JSONNode(io::JSONNode::Type::LIST);
|
|
||||||
src = io::jsonSerialize(test);
|
|
||||||
CHECK_EQ(src, "[]");
|
|
||||||
test = new io::JSONNode(io::JSONNode::Type::OBJECT);
|
|
||||||
test->getObject()->insert(
|
|
||||||
{"test", new io::JSONNode(new std::string("testing"))});
|
|
||||||
src = io::jsonSerialize(test);
|
|
||||||
CHECK_EQ(src, "{\"test\":\"testing\"}");
|
|
||||||
test = new io::JSONNode(io::JSONNode::Type::OBJECT);
|
|
||||||
auto arr = new io::JSONNode(io::JSONNode::Type::LIST);
|
|
||||||
arr->getList()->push_back(new io::JSONNode(1));
|
|
||||||
arr->getList()->push_back(new io::JSONNode(2));
|
|
||||||
test->getObject()->insert({"test", arr});
|
|
||||||
src = io::jsonSerialize(test);
|
|
||||||
CHECK_EQ(src, "{\"test\":[1,2]}");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test jsonDeserialize") {
|
|
||||||
auto raw = std::string("{}");
|
|
||||||
auto res = io::jsonDeserialize(raw.c_str(), raw.size());
|
|
||||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
|
||||||
|
|
||||||
raw = std::string("[]");
|
|
||||||
res = io::jsonDeserialize(raw.c_str(), raw.size());
|
|
||||||
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
|
||||||
|
|
||||||
raw = std::string("[");
|
|
||||||
CHECK_THROWS_AS(
|
|
||||||
io::jsonDeserialize(raw.c_str(), raw.size()), std::runtime_error);
|
|
||||||
|
|
||||||
raw = std::string("[{}, \"test\"]");
|
|
||||||
res = io::jsonDeserialize(raw.c_str(), raw.size());
|
|
||||||
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
|
||||||
CHECK_EQ(res.getList()->size(), 2);
|
|
||||||
CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT));
|
|
||||||
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
|
|
||||||
CHECK_EQ(res.getList()->at(1)->getString(), "test");
|
|
||||||
|
|
||||||
raw = std::string(
|
|
||||||
"{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
|
|
||||||
res = io::jsonDeserialize(raw.c_str(), raw.size());
|
|
||||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
|
||||||
CHECK_EQ(res.getObject()->size(), 1);
|
|
||||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
|
|
||||||
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 3);
|
|
||||||
CHECK(res.getObject()->at("test")->getObject()->at("dtype")->is_type(
|
|
||||||
io::JSONNode::Type::STRING));
|
|
||||||
CHECK_EQ(
|
|
||||||
res.getObject()->at("test")->getObject()->at("dtype")->getString(),
|
|
||||||
"F32");
|
|
||||||
CHECK(res.getObject()->at("test")->getObject()->at("shape")->is_type(
|
|
||||||
io::JSONNode::Type::LIST));
|
|
||||||
CHECK_EQ(
|
|
||||||
res.getObject()->at("test")->getObject()->at("shape")->getList()->size(),
|
|
||||||
1);
|
|
||||||
CHECK(res.getObject()
|
|
||||||
->at("test")
|
|
||||||
->getObject()
|
|
||||||
->at("shape")
|
|
||||||
->getList()
|
|
||||||
->at(0)
|
|
||||||
->is_type(io::JSONNode::Type::NUMBER));
|
|
||||||
CHECK_EQ(
|
|
||||||
res.getObject()
|
|
||||||
->at("test")
|
|
||||||
->getObject()
|
|
||||||
->at("shape")
|
|
||||||
->getList()
|
|
||||||
->at(0)
|
|
||||||
->getNumber(),
|
|
||||||
4);
|
|
||||||
CHECK(res.getObject()
|
|
||||||
->at("test")
|
|
||||||
->getObject()
|
|
||||||
->at("data_offsets")
|
|
||||||
->is_type(io::JSONNode::Type::LIST));
|
|
||||||
CHECK_EQ(
|
|
||||||
res.getObject()
|
|
||||||
->at("test")
|
|
||||||
->getObject()
|
|
||||||
->at("data_offsets")
|
|
||||||
->getList()
|
|
||||||
->size(),
|
|
||||||
2);
|
|
||||||
CHECK(res.getObject()
|
|
||||||
->at("test")
|
|
||||||
->getObject()
|
|
||||||
->at("data_offsets")
|
|
||||||
->getList()
|
|
||||||
->at(0)
|
|
||||||
->is_type(io::JSONNode::Type::NUMBER));
|
|
||||||
CHECK_EQ(
|
|
||||||
res.getObject()
|
|
||||||
->at("test")
|
|
||||||
->getObject()
|
|
||||||
->at("data_offsets")
|
|
||||||
->getList()
|
|
||||||
->at(0)
|
|
||||||
->getNumber(),
|
|
||||||
0);
|
|
||||||
CHECK(res.getObject()
|
|
||||||
->at("test")
|
|
||||||
->getObject()
|
|
||||||
->at("data_offsets")
|
|
||||||
->getList()
|
|
||||||
->at(1)
|
|
||||||
->is_type(io::JSONNode::Type::NUMBER));
|
|
||||||
CHECK_EQ(
|
|
||||||
res.getObject()
|
|
||||||
->at("test")
|
|
||||||
->getObject()
|
|
||||||
->at("data_offsets")
|
|
||||||
->getList()
|
|
||||||
->at(1)
|
|
||||||
->getNumber(),
|
|
||||||
16);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test save_safetensor") {
|
TEST_CASE("test save_safetensor") {
|
||||||
std::string file_path = get_temp_file("test_arr.safetensors");
|
std::string file_path = get_temp_file("test_arr.safetensors");
|
||||||
auto map = std::unordered_map<std::string, array>();
|
auto map = std::unordered_map<std::string, array>();
|
||||||
|
Loading…
Reference in New Issue
Block a user