initial commit

dc-dc-dc 2023-12-17 13:19:08 -05:00
parent 8385f93cea
commit 87ec7b3cf9
7 changed files with 268 additions and 0 deletions

@@ -11,6 +11,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp

@@ -8,6 +8,7 @@
#include "mlx/fft.h"
#include "mlx/ops.h"
#include "mlx/random.h"
#include "mlx/safetensor.h"
#include "mlx/stream.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"

@@ -1057,4 +1057,12 @@ array dequantize(
    int bits = 4,
    StreamOrDevice s = {});

/** Load arrays from a file in .safetensor format */
std::map<std::string, array> load_safetensor(
    std::shared_ptr<io::Reader> in_stream,
    StreamOrDevice s = {});
std::map<std::string, array> load_safetensor(
    const std::string& file,
    StreamOrDevice s = {});

} // namespace mlx::core
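
A minimal usage sketch of the declarations above (illustrative only: "model.safetensors" is a placeholder path, and at this commit the loader still returns an empty map because the json header is not parsed yet):

#include <iostream>
#include "mlx/mlx.h"

int main() {
  // load_safetensor returns a map from tensor name to array;
  // the file name here is hypothetical.
  auto tensors = mlx::core::load_safetensor("model.safetensors");
  for (auto& [name, arr] : tensors) {
    std::cout << name << ": " << arr.size() << " elements" << std::endl;
  }
  return 0;
}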

mlx/safetensor.cpp Normal file
@@ -0,0 +1,160 @@
#include "mlx/safetensor.h"
#include <stack>
namespace mlx::core {
namespace io {
Token Tokenizer::getToken() {
  // Skip whitespace, checking bounds before each read.
  while (this->hasMoreTokens() &&
         (this->_data[this->_loc] == ' ' || this->_data[this->_loc] == '\n')) {
    this->_loc++;
  }
  if (!this->hasMoreTokens()) {
    return Token{TOKEN::NULL_TYPE};
  }
  char nextChar = this->_data[this->_loc];
  switch (nextChar) {
    case '{':
      this->_loc++;
      return Token{TOKEN::CURLY_OPEN};
    case '}':
      this->_loc++;
      return Token{TOKEN::CURLY_CLOSE};
    case ':':
      this->_loc++;
      return Token{TOKEN::COLON};
    case '[':
      this->_loc++;
      return Token{TOKEN::ARRAY_OPEN};
    case ']':
      this->_loc++;
      return Token{TOKEN::ARRAY_CLOSE};
    case ',':
      this->_loc++;
      return Token{TOKEN::COMMA};
    case '"': {
      size_t start = this->_loc;
      this->_loc++;
      // Scan to the closing quote without reading past the end of the data.
      while (this->hasMoreTokens() && this->_data[this->_loc] != '"') {
        this->_loc++;
      }
      if (!this->hasMoreTokens()) {
        throw std::runtime_error("unterminated string in json header");
      }
      // Step past the closing quote.
      this->_loc++;
      return Token{TOKEN::STRING, start, this->_loc};
    }
    default: {
      // Anything else is treated as a number; scan to the next delimiter.
      size_t start = this->_loc;
      while (this->hasMoreTokens() && this->_data[this->_loc] != ',' &&
             this->_data[this->_loc] != '}' && this->_data[this->_loc] != ']' &&
             this->_data[this->_loc] != ' ' && this->_data[this->_loc] != '\n') {
        this->_loc++;
      }
      if (!this->hasMoreTokens()) {
        throw std::runtime_error("unexpected end of json header");
      }
      return Token{TOKEN::NUMBER, start, this->_loc};
    }
  }
}
// JSONNode parseJson(char* data, size_t len) {
//   auto tokenizer = Tokenizer(data, len);
//   std::stack<JSONNode*> ctx;
//   auto token = tokenizer.getToken();
//   auto parent = new JSONNode();
//   switch (token.type) {
//     case TOKEN::CURLY_OPEN:
//       parent->setObject(new JSONObject());
//       break;
//     case TOKEN::ARRAY_OPEN:
//       parent->setList(new JSONList());
//       break;
//     default:
//       throw new std::runtime_error("invalid json");
//   }
//   ctx.push(parent);
//   while (tokenizer.hasMoreTokens()) {
//     auto token = tokenizer.getToken();
//     switch (token.type) {
//       case TOKEN::CURLY_OPEN:
//         ctx.push(new JSONNode(JSONNode::Type::OBJECT));
//         break;
//       case TOKEN::CURLY_CLOSE:
//         if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
//           auto obj = ctx.top();
//           ctx.pop();
//           if (ctx.top()->is_type(JSONNode::Type::LIST)) {
//             auto list = ctx.top()->getList();
//             list->push_back(obj);
//           } else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
//             //
//             auto key = ctx.top();
//             ctx.pop();
//             if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
//               ctx.top()->getObject()->insert({key->getString(), obj});
//             }
//           }
//         } else {
//           throw new std::runtime_error("invalid json");
//         }
//         break;
//       case TOKEN::COLON:
//         break;
//       case TOKEN::ARRAY_OPEN:
//         break;
//       case TOKEN::ARRAY_CLOSE:
//         break;
//       case TOKEN::COMMA:
//         break;
//       case TOKEN::NULL_TYPE:
//         break;
//       case TOKEN::STRING:
//         break;
//       case TOKEN::NUMBER:
//         break;
//     }
//   }
// }
} // namespace io
/** Load arrays from a reader in safetensor format */
std::map<std::string, array> load_safetensor(
    std::shared_ptr<io::Reader> in_stream,
    StreamOrDevice s) {
  ////////////////////////////////////////////////////////
  // Open and check file
  if (!in_stream->good() || !in_stream->is_open()) {
    throw std::runtime_error(
        "[load_safetensor] Failed to open " + in_stream->label());
  }

  // The first 8 bytes hold the length of the json header
  uint64_t jsonHeaderLength = 0;
  in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
  if (jsonHeaderLength == 0) {
    throw std::runtime_error(
        "[load_safetensor] Invalid json header length " + in_stream->label());
  }

  // Load the json metadata
  std::vector<char> json(jsonHeaderLength);
  in_stream->read(json.data(), jsonHeaderLength);

  // Parse the json raw data (parsing not implemented yet; returns an empty map)
  std::map<std::string, array> res;
  return res;
}

std::map<std::string, array> load_safetensor(
    const std::string& file,
    StreamOrDevice s) {
  return load_safetensor(std::make_shared<io::FileReader>(file), s);
}
} // namespace mlx::core
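
For context on what load_safetensor is reading: a .safetensors file begins with an 8-byte little-endian unsigned integer giving the size of a json header, followed by that many bytes of json (mapping tensor names to dtype, shape, and data offsets), followed by the raw tensor bytes. Below is a self-contained sketch of reading just that header with std::ifstream; it is an illustration of the format rather than the code in this commit, and it assumes a little-endian host.

#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Returns the raw json header of a .safetensors file, e.g.
// {"weight": {"dtype": "F32", "shape": [2, 3], "data_offsets": [0, 24]}}
std::string read_safetensors_header(const std::string& path) {
  std::ifstream f(path, std::ios::binary);
  // First 8 bytes: little-endian length of the json header.
  uint64_t header_len = 0;
  f.read(reinterpret_cast<char*>(&header_len), sizeof(header_len));
  // Next header_len bytes: the json metadata itself.
  std::vector<char> header(header_len);
  f.read(header.data(), header_len);
  return std::string(header.begin(), header.end());
}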

mlx/safetensor.h Normal file
@@ -0,0 +1,79 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "mlx/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace io {

class JSONNode;

using JSONObject = std::map<std::string, std::shared_ptr<JSONNode>>;
using JSONList = std::vector<std::shared_ptr<JSONNode>>;

class JSONNode {
 public:
  enum class Type { OBJECT, LIST, STRING, NUMBER, BOOLEAN, NULL_TYPE };

  JSONNode() : _type(Type::NULL_TYPE) {}
  JSONNode(Type type) : _type(type) {
    // Allocate the default value for container types.
    if (type == Type::OBJECT) {
      this->_values.object = new JSONObject();
    } else if (type == Type::LIST) {
      this->_values.list = new JSONList();
    }
  }

  inline bool is_type(Type t) const {
    return this->_type == t;
  }

 private:
  union Values {
    JSONObject* object;
    JSONList* list;
    std::string* s;
    float fValue;
  } _values;
  Type _type;
};

enum class TOKEN {
  CURLY_OPEN,
  CURLY_CLOSE,
  COLON,
  STRING,
  NUMBER,
  ARRAY_OPEN,
  ARRAY_CLOSE,
  COMMA,
  NULL_TYPE,
};

struct Token {
  TOKEN type;
  size_t start;
  size_t end;
};

class Tokenizer {
 public:
  Tokenizer(const char* data, size_t len) : _data(data), _len(len), _loc(0) {}

  Token getToken();

  inline bool hasMoreTokens() {
    return this->_loc < this->_len;
  }

 private:
  const char* _data;
  size_t _len;
  size_t _loc;
};

} // namespace io
} // namespace mlx::core
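
The commented-out parseJson in safetensor.cpp already calls setObject, setList, getObject, getList, and getString on JSONNode, none of which are declared in this header yet. A rough sketch of what such members could look like inside JSONNode's public section, given the existing union storage (hypothetical, not part of this commit):

// Hypothetical JSONNode accessors matching the calls in the commented-out
// parser; ownership and cleanup of the raw pointers is left open here.
void setObject(JSONObject* object) {
  this->_type = Type::OBJECT;
  this->_values.object = object;
}
void setList(JSONList* list) {
  this->_type = Type::LIST;
  this->_values.list = list;
}
JSONObject* getObject() {
  return this->_values.object;
}
JSONList* getList() {
  return this->_values.list;
}
const std::string& getString() {
  return *this->_values.s;
}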

temp.safe Normal file

Binary file not shown.

@ -14,6 +14,25 @@ std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
TEST_CASE("test tokenizer") {
auto raw = std::string(" { \"testing\": [1 , \"test\"]} ");
auto tokenizer = io::Tokenizer(raw.c_str(), raw.size());
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_OPEN);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COLON);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_OPEN);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NUMBER);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::COMMA);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::STRING);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::ARRAY_CLOSE);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::CURLY_CLOSE);
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
}
// TEST_CASE("test load_safetensor") {
// auto array = load_safetensor("../../temp.safe");
// }
TEST_CASE("test single array serialization") {
// Basic test
{