mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
expand dtype support
This commit is contained in:
parent
ba869e5e71
commit
5aa0b1f632
@ -16,6 +16,8 @@ Token Tokenizer::getToken() {
|
||||
if (!this->hasMoreTokens()) {
|
||||
return Token{TOKEN::NULL_TYPE};
|
||||
}
|
||||
// loc is not that important here, but need to increment location
|
||||
// so might as well do it all in one line
|
||||
switch (nextChar) {
|
||||
case '{':
|
||||
return Token{TOKEN::CURLY_OPEN, ++this->_loc};
|
||||
@ -34,7 +36,7 @@ Token Tokenizer::getToken() {
|
||||
while (_data[++this->_loc] != '"' && this->hasMoreTokens())
|
||||
;
|
||||
if (!this->hasMoreTokens()) {
|
||||
throw new std::runtime_error("no more chars to parse");
|
||||
throw std::runtime_error("no more chars to parse");
|
||||
}
|
||||
return Token{TOKEN::STRING, start, ++this->_loc};
|
||||
}
|
||||
@ -46,7 +48,7 @@ Token Tokenizer::getToken() {
|
||||
nextChar = this->_data[++this->_loc];
|
||||
}
|
||||
if (!this->hasMoreTokens()) {
|
||||
throw new std::runtime_error("no more chars to parse");
|
||||
throw std::runtime_error("no more chars to parse");
|
||||
}
|
||||
return Token{TOKEN::NUMBER, start, this->_loc};
|
||||
}
|
||||
@ -95,7 +97,6 @@ JSONNode parseJson(const char* data, size_t len) {
|
||||
return *obj;
|
||||
}
|
||||
if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||
// key is above
|
||||
auto key = ctx.top();
|
||||
ctx.pop();
|
||||
if (ctx.top()->is_type(JSONNode::Type::OBJECT)) {
|
||||
@ -159,11 +160,42 @@ JSONNode parseJson(const char* data, size_t len) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("[unreachable] invalid json");
|
||||
throw std::runtime_error(
|
||||
"[parseJson] json was invalid and could not be parsed");
|
||||
}
|
||||
|
||||
} // namespace io
|
||||
|
||||
/**
 * Map a safetensors dtype tag (one of the ST_* strings) to the matching Dtype.
 *
 * Throws std::runtime_error carrying the offending tag when it matches none
 * of the known ST_* tags.
 */
Dtype dtype_from_safe_tensor_str(std::string str) {
  // Pairs each known tag with the Dtype it stands for. The tags are all
  // distinct, so the order in which they are probed does not affect the result.
  struct TagEntry {
    const char* tag;
    Dtype dtype;
  };
  static const TagEntry known_tags[] = {
      {ST_F32, float32},
      {ST_F16, float16},
      {ST_BF16, bfloat16},
      {ST_I64, int64},
      {ST_I32, int32},
      {ST_I16, int16},
      {ST_I8, int8},
      {ST_U64, uint64},
      {ST_U32, uint32},
      {ST_U16, uint16},
      {ST_U8, uint8},
      {ST_BOOL, bool_},
  };

  for (const auto& entry : known_tags) {
    if (str == entry.tag) {
      return entry.dtype;
    }
  }

  // No tag matched: report the unknown dtype string to the caller.
  throw std::runtime_error("[safetensor] unsupported dtype " + str);
}
|
||||
|
||||
/** Load array from reader in safetensor format */
|
||||
std::unordered_map<std::string, array> load_safetensor(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
@ -185,12 +217,13 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
char json[jsonHeaderLength];
|
||||
in_stream->read(json, jsonHeaderLength);
|
||||
auto metadata = io::parseJson(json, jsonHeaderLength);
|
||||
// Should always be an object on the top-level
|
||||
if (!metadata.is_type(io::JSONNode::Type::OBJECT)) {
|
||||
throw std::runtime_error(
|
||||
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
||||
}
|
||||
size_t offset = jsonHeaderLength + 8;
|
||||
// Parse the json raw data
|
||||
// Load the arrays using metadata
|
||||
std::unordered_map<std::string, array> res;
|
||||
for (auto& [key, obj] : *metadata.getObject()) {
|
||||
std::string dtype = obj->getObject()->at("dtype")->getString();
|
||||
@ -204,23 +237,19 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
for (const auto& offset : *data_offsets) {
|
||||
data_offsets_vec.push_back(offset->getNumber());
|
||||
}
|
||||
if (dtype == "F32") {
|
||||
auto loaded_array = array(
|
||||
shape_vec,
|
||||
float32,
|
||||
std::make_unique<Load>(
|
||||
to_stream(s),
|
||||
in_stream,
|
||||
offset + data_offsets->at(0)->getNumber(),
|
||||
offset + data_offsets->at(1)->getNumber(),
|
||||
false),
|
||||
std::vector<array>{});
|
||||
res.insert({key, loaded_array});
|
||||
}
|
||||
Dtype type = dtype_from_safe_tensor_str(dtype);
|
||||
auto loaded_array = array(
|
||||
shape_vec,
|
||||
float32,
|
||||
std::make_unique<Load>(
|
||||
to_stream(s),
|
||||
in_stream,
|
||||
offset + data_offsets->at(0)->getNumber(),
|
||||
offset + data_offsets->at(1)->getNumber(),
|
||||
false),
|
||||
std::vector<array>{});
|
||||
res.insert({key, loaded_array});
|
||||
}
|
||||
// for (auto& [key, arr] : res) {
|
||||
// arr.eval();
|
||||
// }
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,26 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// String tags naming each dtype; dtype_from_safe_tensor_str compares its
// argument against these to pick the matching Dtype.
// Floating-point tags.
#define ST_F16 "F16"
|
||||
#define ST_BF16 "BF16"
|
||||
#define ST_F32 "F32"
|
||||
|
||||
// Boolean, signed-integer (I*) and unsigned-integer (U*) tags.
#define ST_BOOL "BOOL"
|
||||
#define ST_I8 "I8"
|
||||
#define ST_I16 "I16"
|
||||
#define ST_I32 "I32"
|
||||
#define ST_I64 "I64"
|
||||
#define ST_U8 "U8"
|
||||
#define ST_U16 "U16"
|
||||
#define ST_U32 "U32"
|
||||
#define ST_U64 "U64"
|
||||
|
||||
namespace io {
|
||||
// NOTE: This json parser is a bare minimum implementation for safetensors,
|
||||
// it does not support all of json features, and does not have a lot of edge case
|
||||
// catches. This is okay as safe tensor json is very simple and we can assume it
|
||||
// is always valid and well formed, but this should not be used for general json
|
||||
// parsing
|
||||
class JSONNode;
|
||||
using JSONObject = std::unordered_map<std::string, JSONNode*>;
|
||||
using JSONList = std::vector<JSONNode*>;
|
||||
|
Loading…
Reference in New Issue
Block a user