From 5aa0b1f632f8ebbed4e5fc5a89cf57e028dd0a97 Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Mon, 18 Dec 2023 15:19:04 -0500 Subject: [PATCH] expand dtype support --- mlx/safetensor.cpp | 71 ++++++++++++++++++++++++++++++++-------------- mlx/safetensor.h | 19 +++++++++++++ 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index a6c0ce69d..410255627 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -16,6 +16,8 @@ Token Tokenizer::getToken() { if (!this->hasMoreTokens()) { return Token{TOKEN::NULL_TYPE}; } + // loc is not that important here, but need to increment location + // so might as well do it all in one line switch (nextChar) { case '{': return Token{TOKEN::CURLY_OPEN, ++this->_loc}; @@ -34,7 +36,7 @@ Token Tokenizer::getToken() { while (_data[++this->_loc] != '"' && this->hasMoreTokens()) ; if (!this->hasMoreTokens()) { - throw new std::runtime_error("no more chars to parse"); + throw std::runtime_error("no more chars to parse"); } return Token{TOKEN::STRING, start, ++this->_loc}; } @@ -46,7 +48,7 @@ Token Tokenizer::getToken() { nextChar = this->_data[++this->_loc]; } if (!this->hasMoreTokens()) { - throw new std::runtime_error("no more chars to parse"); + throw std::runtime_error("no more chars to parse"); } return Token{TOKEN::NUMBER, start, this->_loc}; } @@ -95,7 +97,6 @@ JSONNode parseJson(const char* data, size_t len) { return *obj; } if (ctx.top()->is_type(JSONNode::Type::STRING)) { - // key is above auto key = ctx.top(); ctx.pop(); if (ctx.top()->is_type(JSONNode::Type::OBJECT)) { @@ -159,11 +160,42 @@ JSONNode parseJson(const char* data, size_t len) { break; } } - throw std::runtime_error("[unreachable] invalid json"); + throw std::runtime_error( + "[parseJson] json was invalid and could not be parsed"); } } // namespace io +Dtype dtype_from_safe_tensor_str(std::string str) { + if (str == ST_F32) { + return float32; + } else if (str == ST_F16) { + return float16; + } else if (str == 
ST_BF16) { + return bfloat16; + } else if (str == ST_I64) { + return int64; + } else if (str == ST_I32) { + return int32; + } else if (str == ST_I16) { + return int16; + } else if (str == ST_I8) { + return int8; + } else if (str == ST_U64) { + return uint64; + } else if (str == ST_U32) { + return uint32; + } else if (str == ST_U16) { + return uint16; + } else if (str == ST_U8) { + return uint8; + } else if (str == ST_BOOL) { + return bool_; + } else { + throw std::runtime_error("[safetensor] unsupported dtype " + str); + } +} + /** Load array from reader in safetensor format */ std::unordered_map load_safetensor( std::shared_ptr in_stream, @@ -185,12 +217,13 @@ std::unordered_map load_safetensor( char json[jsonHeaderLength]; in_stream->read(json, jsonHeaderLength); auto metadata = io::parseJson(json, jsonHeaderLength); + // Should always be an object on the top-level if (!metadata.is_type(io::JSONNode::Type::OBJECT)) { throw std::runtime_error( "[load_safetensor] Invalid json metadata " + in_stream->label()); } size_t offset = jsonHeaderLength + 8; - // Parse the json raw data + // Load each array with the dtype and data offsets given in its metadata std::unordered_map res; for (auto& [key, obj] : *metadata.getObject()) { std::string dtype = obj->getObject()->at("dtype")->getString(); @@ -204,23 +237,19 @@ std::unordered_map load_safetensor( for (const auto& offset : *data_offsets) { data_offsets_vec.push_back(offset->getNumber()); } - if (dtype == "F32") { - auto loaded_array = array( - shape_vec, - float32, - std::make_unique( - to_stream(s), - in_stream, - offset + data_offsets->at(0)->getNumber(), - offset + data_offsets->at(1)->getNumber(), - false), - std::vector{}); - res.insert({key, loaded_array}); - } + Dtype type = dtype_from_safe_tensor_str(dtype); + auto loaded_array = array( + shape_vec, + type, + std::make_unique( + to_stream(s), + in_stream, + offset + data_offsets->at(0)->getNumber(), + offset + data_offsets->at(1)->getNumber(), + false), + std::vector{}); + res.insert({key,
loaded_array}); + } - // for (auto& [key, arr] : res) { - // arr.eval(); - // } return res; } diff --git a/mlx/safetensor.h b/mlx/safetensor.h index f6057dea5..662c4ee00 100644 --- a/mlx/safetensor.h +++ b/mlx/safetensor.h @@ -9,7 +9,26 @@ namespace mlx::core { +#define ST_F16 "F16" +#define ST_BF16 "BF16" +#define ST_F32 "F32" + +#define ST_BOOL "BOOL" +#define ST_I8 "I8" +#define ST_I16 "I16" +#define ST_I32 "I32" +#define ST_I64 "I64" +#define ST_U8 "U8" +#define ST_U16 "U16" +#define ST_U32 "U32" +#define ST_U64 "U64" + namespace io { +// NOTE: This JSON parser is a bare-minimum implementation for safetensors; +// it does not support all JSON features and does not catch many edge cases. +// This is okay because safetensors JSON is very simple and we can assume it +// is always valid and well formed, but it should not be used for general JSON +// parsing. class JSONNode; using JSONObject = std::unordered_map; using JSONList = std::vector;