expand dtype support

This commit is contained in:
dc-dc-dc 2023-12-18 15:19:04 -05:00
parent ba869e5e71
commit 5aa0b1f632
2 changed files with 69 additions and 21 deletions

View File

@ -16,6 +16,8 @@ Token Tokenizer::getToken() {
if (!this->hasMoreTokens()) { if (!this->hasMoreTokens()) {
return Token{TOKEN::NULL_TYPE}; return Token{TOKEN::NULL_TYPE};
} }
// loc is not that important here, but need to increment location
// so might as well do it all in one line
switch (nextChar) { switch (nextChar) {
case '{': case '{':
return Token{TOKEN::CURLY_OPEN, ++this->_loc}; return Token{TOKEN::CURLY_OPEN, ++this->_loc};
@ -34,7 +36,7 @@ Token Tokenizer::getToken() {
while (_data[++this->_loc] != '"' && this->hasMoreTokens()) while (_data[++this->_loc] != '"' && this->hasMoreTokens())
; ;
if (!this->hasMoreTokens()) { if (!this->hasMoreTokens()) {
throw new std::runtime_error("no more chars to parse"); throw std::runtime_error("no more chars to parse");
} }
return Token{TOKEN::STRING, start, ++this->_loc}; return Token{TOKEN::STRING, start, ++this->_loc};
} }
@ -46,7 +48,7 @@ Token Tokenizer::getToken() {
nextChar = this->_data[++this->_loc]; nextChar = this->_data[++this->_loc];
} }
if (!this->hasMoreTokens()) { if (!this->hasMoreTokens()) {
throw new std::runtime_error("no more chars to parse"); throw std::runtime_error("no more chars to parse");
} }
return Token{TOKEN::NUMBER, start, this->_loc}; return Token{TOKEN::NUMBER, start, this->_loc};
} }
@ -95,7 +97,6 @@ JSONNode parseJson(const char* data, size_t len) {
return *obj; return *obj;
} }
if (ctx.top()->is_type(JSONNode::Type::STRING)) { if (ctx.top()->is_type(JSONNode::Type::STRING)) {
// key is above
auto key = ctx.top(); auto key = ctx.top();
ctx.pop(); ctx.pop();
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
@ -159,11 +160,42 @@ JSONNode parseJson(const char* data, size_t len) {
break; break;
} }
} }
throw std::runtime_error("[unreachable] invalid json"); throw std::runtime_error(
"[parseJson] json was invalid and could not be parsed");
} }
} // namespace io } // namespace io
/**
 * Map a safetensor dtype name onto the corresponding Dtype.
 *
 * The name is compared against each ST_* identifier in turn; the first match
 * wins. Throws std::runtime_error if the name matches none of them.
 */
Dtype dtype_from_safe_tensor_str(std::string str) {
  // Floating point types.
  if (str == ST_F32) {
    return float32;
  }
  if (str == ST_F16) {
    return float16;
  }
  if (str == ST_BF16) {
    return bfloat16;
  }
  // Signed integers.
  if (str == ST_I64) {
    return int64;
  }
  if (str == ST_I32) {
    return int32;
  }
  if (str == ST_I16) {
    return int16;
  }
  if (str == ST_I8) {
    return int8;
  }
  // Unsigned integers.
  if (str == ST_U64) {
    return uint64;
  }
  if (str == ST_U32) {
    return uint32;
  }
  if (str == ST_U16) {
    return uint16;
  }
  if (str == ST_U8) {
    return uint8;
  }
  // Boolean.
  if (str == ST_BOOL) {
    return bool_;
  }
  // Nothing matched: report the offending name to the caller.
  throw std::runtime_error("[safetensor] unsupported dtype " + str);
}
/** Load array from reader in safetensor format */ /** Load array from reader in safetensor format */
std::unordered_map<std::string, array> load_safetensor( std::unordered_map<std::string, array> load_safetensor(
std::shared_ptr<io::Reader> in_stream, std::shared_ptr<io::Reader> in_stream,
@ -185,12 +217,13 @@ std::unordered_map<std::string, array> load_safetensor(
char json[jsonHeaderLength]; char json[jsonHeaderLength];
in_stream->read(json, jsonHeaderLength); in_stream->read(json, jsonHeaderLength);
auto metadata = io::parseJson(json, jsonHeaderLength); auto metadata = io::parseJson(json, jsonHeaderLength);
// Should always be an object on the top-level
if (!metadata.is_type(io::JSONNode::Type::OBJECT)) { if (!metadata.is_type(io::JSONNode::Type::OBJECT)) {
throw std::runtime_error( throw std::runtime_error(
"[load_safetensor] Invalid json metadata " + in_stream->label()); "[load_safetensor] Invalid json metadata " + in_stream->label());
} }
size_t offset = jsonHeaderLength + 8; size_t offset = jsonHeaderLength + 8;
// Parse the json raw data // Load the arrays using metadata
std::unordered_map<std::string, array> res; std::unordered_map<std::string, array> res;
for (auto& [key, obj] : *metadata.getObject()) { for (auto& [key, obj] : *metadata.getObject()) {
std::string dtype = obj->getObject()->at("dtype")->getString(); std::string dtype = obj->getObject()->at("dtype")->getString();
@ -204,7 +237,7 @@ std::unordered_map<std::string, array> load_safetensor(
for (const auto& offset : *data_offsets) { for (const auto& offset : *data_offsets) {
data_offsets_vec.push_back(offset->getNumber()); data_offsets_vec.push_back(offset->getNumber());
} }
if (dtype == "F32") { Dtype type = dtype_from_safe_tensor_str(dtype);
auto loaded_array = array( auto loaded_array = array(
shape_vec, shape_vec,
float32, float32,
@ -217,10 +250,6 @@ std::unordered_map<std::string, array> load_safetensor(
std::vector<array>{}); std::vector<array>{});
res.insert({key, loaded_array}); res.insert({key, loaded_array});
} }
}
// for (auto& [key, arr] : res) {
// arr.eval();
// }
return res; return res;
} }

View File

@ -9,7 +9,26 @@
namespace mlx::core { namespace mlx::core {
// Dtype name strings used by the safetensor loader; dtype_from_safe_tensor_str
// compares the "dtype" metadata string against these.
#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"
#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"
namespace io { namespace io {
// NOTE: This JSON parser is a bare-minimum implementation for safetensors.
// It does not support all JSON features and does not catch many edge cases.
// This is acceptable because safetensors JSON is very simple and we can assume
// it is always valid and well formed, but it should not be used for general
// JSON parsing.
class JSONNode; class JSONNode;
using JSONObject = std::unordered_map<std::string, JSONNode*>; using JSONObject = std::unordered_map<std::string, JSONNode*>;
using JSONList = std::vector<JSONNode*>; using JSONList = std::vector<JSONNode*>;