expand dtype support

This commit is contained in:
dc-dc-dc 2023-12-18 15:19:04 -05:00
parent ba869e5e71
commit 5aa0b1f632
2 changed files with 69 additions and 21 deletions

View File

@ -16,6 +16,8 @@ Token Tokenizer::getToken() {
if (!this->hasMoreTokens()) {
return Token{TOKEN::NULL_TYPE};
}
// loc is not that important here, but need to increment location
// so might as well do it all in one line
switch (nextChar) {
case '{':
return Token{TOKEN::CURLY_OPEN, ++this->_loc};
@ -34,7 +36,7 @@ Token Tokenizer::getToken() {
while (_data[++this->_loc] != '"' && this->hasMoreTokens())
;
if (!this->hasMoreTokens()) {
throw new std::runtime_error("no more chars to parse");
throw std::runtime_error("no more chars to parse");
}
return Token{TOKEN::STRING, start, ++this->_loc};
}
@ -46,7 +48,7 @@ Token Tokenizer::getToken() {
nextChar = this->_data[++this->_loc];
}
if (!this->hasMoreTokens()) {
throw new std::runtime_error("no more chars to parse");
throw std::runtime_error("no more chars to parse");
}
return Token{TOKEN::NUMBER, start, this->_loc};
}
@ -95,7 +97,6 @@ JSONNode parseJson(const char* data, size_t len) {
return *obj;
}
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
// key is above
auto key = ctx.top();
ctx.pop();
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
@ -159,11 +160,42 @@ JSONNode parseJson(const char* data, size_t len) {
break;
}
}
throw std::runtime_error("[unreachable] invalid json");
throw std::runtime_error(
"[parseJson] json was invalid and could not be parsed");
}
} // namespace io
Dtype dtype_from_safe_tensor_str(std::string str) {
if (str == ST_F32) {
return float32;
} else if (str == ST_F16) {
return float16;
} else if (str == ST_BF16) {
return bfloat16;
} else if (str == ST_I64) {
return int64;
} else if (str == ST_I32) {
return int32;
} else if (str == ST_I16) {
return int16;
} else if (str == ST_I8) {
return int8;
} else if (str == ST_U64) {
return uint64;
} else if (str == ST_U32) {
return uint32;
} else if (str == ST_U16) {
return uint16;
} else if (str == ST_U8) {
return uint8;
} else if (str == ST_BOOL) {
return bool_;
} else {
throw std::runtime_error("[safetensor] unsupported dtype " + str);
}
}
/** Load array from reader in safetensor format */
std::unordered_map<std::string, array> load_safetensor(
std::shared_ptr<io::Reader> in_stream,
@ -185,12 +217,13 @@ std::unordered_map<std::string, array> load_safetensor(
char json[jsonHeaderLength];
in_stream->read(json, jsonHeaderLength);
auto metadata = io::parseJson(json, jsonHeaderLength);
// Should always be an object on the top-level
if (!metadata.is_type(io::JSONNode::Type::OBJECT)) {
throw std::runtime_error(
"[load_safetensor] Invalid json metadata " + in_stream->label());
}
size_t offset = jsonHeaderLength + 8;
// Parse the json raw data
// Load the arrays using metadata
std::unordered_map<std::string, array> res;
for (auto& [key, obj] : *metadata.getObject()) {
std::string dtype = obj->getObject()->at("dtype")->getString();
@ -204,23 +237,19 @@ std::unordered_map<std::string, array> load_safetensor(
for (const auto& offset : *data_offsets) {
data_offsets_vec.push_back(offset->getNumber());
}
if (dtype == "F32") {
auto loaded_array = array(
shape_vec,
float32,
std::make_unique<Load>(
to_stream(s),
in_stream,
offset + data_offsets->at(0)->getNumber(),
offset + data_offsets->at(1)->getNumber(),
false),
std::vector<array>{});
res.insert({key, loaded_array});
}
Dtype type = dtype_from_safe_tensor_str(dtype);
auto loaded_array = array(
shape_vec,
float32,
std::make_unique<Load>(
to_stream(s),
in_stream,
offset + data_offsets->at(0)->getNumber(),
offset + data_offsets->at(1)->getNumber(),
false),
std::vector<array>{});
res.insert({key, loaded_array});
}
// for (auto& [key, arr] : res) {
// arr.eval();
// }
return res;
}

View File

@ -9,7 +9,26 @@
namespace mlx::core {
#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"
#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"
namespace io {
// NOTE: This json parser is a bare minimum implementation for safetensors,
// it does not support all of json features, and does not have alot of edge case
// catches. This is okay as safe tensor json is very simple and we can assume it
// is always valid and well formed, but this should not be used for general json
// parsing
class JSONNode;
using JSONObject = std::unordered_map<std::string, JSONNode*>;
using JSONList = std::vector<JSONNode*>;