mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
saving works
This commit is contained in:
parent
5aa0b1f632
commit
f09bcc7d50
2
.gitignore
vendored
2
.gitignore
vendored
@ -8,7 +8,7 @@ __pycache__/
|
||||
|
||||
# tensor files
|
||||
*.safe
|
||||
*.safetensor
|
||||
*.safetensors
|
||||
|
||||
# Metal libraries
|
||||
*.metallib
|
||||
|
@ -1065,4 +1065,10 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
const std::string& file,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
void save_safetensor(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
std::unordered_map<std::string, array>);
|
||||
void save_safetensor(
|
||||
const std::string& file,
|
||||
std::unordered_map<std::string, array>);
|
||||
} // namespace mlx::core
|
||||
|
@ -55,7 +55,7 @@ Token Tokenizer::getToken() {
|
||||
}
|
||||
}
|
||||
|
||||
JSONNode parseJson(const char* data, size_t len) {
|
||||
JSONNode jsonDeserialize(const char* data, size_t len) {
|
||||
auto tokenizer = Tokenizer(data, len);
|
||||
std::stack<JSONNode*> ctx;
|
||||
while (tokenizer.hasMoreTokens()) {
|
||||
@ -137,7 +137,7 @@ JSONNode parseJson(const char* data, size_t len) {
|
||||
case TOKEN::NUMBER: {
|
||||
// TODO: is there an easier way of doing this.
|
||||
auto str = new std::string(data + token.start, token.end - token.start);
|
||||
float val = strtof(str->c_str(), nullptr);
|
||||
auto val = strtoul(str->c_str(), nullptr, 10);
|
||||
if (ctx.top()->is_type(JSONNode::Type::LIST)) {
|
||||
ctx.top()->getList()->push_back(new JSONNode(val));
|
||||
} else if (ctx.top()->is_type(JSONNode::Type::STRING)) {
|
||||
@ -152,21 +152,83 @@ JSONNode parseJson(const char* data, size_t len) {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TOKEN::COMMA:
|
||||
break;
|
||||
case TOKEN::COLON:
|
||||
break;
|
||||
case TOKEN::NULL_TYPE:
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"[parseJson] json was invalid and could not be parsed");
|
||||
"[jsonDeserialize] json was invalid and could not be parsed");
|
||||
}
|
||||
|
||||
std::string jsonSerialize(JSONNode* node) {
|
||||
std::string res;
|
||||
if (node->is_type(JSONNode::Type::STRING)) {
|
||||
return "\"" + node->getString() + "\"";
|
||||
}
|
||||
if (node->is_type(JSONNode::Type::NUMBER)) {
|
||||
return std::to_string(node->getNumber());
|
||||
}
|
||||
if (node->is_type(JSONNode::Type::LIST)) {
|
||||
res += "[";
|
||||
for (auto& item : *node->getList()) {
|
||||
res += jsonSerialize(item);
|
||||
res += ",";
|
||||
}
|
||||
if (res.back() == ',') {
|
||||
res.pop_back();
|
||||
}
|
||||
res += "]";
|
||||
return res;
|
||||
}
|
||||
if (node->is_type(JSONNode::Type::OBJECT)) {
|
||||
res += "{";
|
||||
for (auto& [key, item] : *node->getObject()) {
|
||||
res += "\"" + key + "\":";
|
||||
res += jsonSerialize(item);
|
||||
res += ",";
|
||||
}
|
||||
if (res.back() == ',') {
|
||||
res.pop_back();
|
||||
}
|
||||
res += "}";
|
||||
return res;
|
||||
}
|
||||
|
||||
throw std::runtime_error("[jsonSerialize] invalid json node");
|
||||
}
|
||||
|
||||
} // namespace io
|
||||
std::string dtype_to_safetensor_str(Dtype t) {
|
||||
if (t == float32) {
|
||||
return ST_F32;
|
||||
} else if (t == bfloat16) {
|
||||
return ST_BF16;
|
||||
} else if (t == float16) {
|
||||
return ST_F16;
|
||||
} else if (t == int64) {
|
||||
return ST_I64;
|
||||
} else if (t == int32) {
|
||||
return ST_I32;
|
||||
} else if (t == int16) {
|
||||
return ST_I16;
|
||||
} else if (t == int8) {
|
||||
return ST_I8;
|
||||
} else if (t == uint64) {
|
||||
return ST_U64;
|
||||
} else if (t == uint32) {
|
||||
return ST_U32;
|
||||
} else if (t == uint16) {
|
||||
return ST_U16;
|
||||
} else if (t == uint8) {
|
||||
return ST_U8;
|
||||
} else if (t == bool_) {
|
||||
return ST_BOOL;
|
||||
} else {
|
||||
throw std::runtime_error("[safetensor] unsupported dtype");
|
||||
}
|
||||
}
|
||||
|
||||
Dtype dtype_from_safe_tensor_str(std::string str) {
|
||||
Dtype dtype_from_safetensor_str(std::string str) {
|
||||
if (str == ST_F32) {
|
||||
return float32;
|
||||
} else if (str == ST_F16) {
|
||||
@ -216,7 +278,7 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
// Load the json metadata
|
||||
char json[jsonHeaderLength];
|
||||
in_stream->read(json, jsonHeaderLength);
|
||||
auto metadata = io::parseJson(json, jsonHeaderLength);
|
||||
auto metadata = io::jsonDeserialize(json, jsonHeaderLength);
|
||||
// Should always be an object on the top-level
|
||||
if (!metadata.is_type(io::JSONNode::Type::OBJECT)) {
|
||||
throw std::runtime_error(
|
||||
@ -237,7 +299,7 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
for (const auto& offset : *data_offsets) {
|
||||
data_offsets_vec.push_back(offset->getNumber());
|
||||
}
|
||||
Dtype type = dtype_from_safe_tensor_str(dtype);
|
||||
Dtype type = dtype_from_safetensor_str(dtype);
|
||||
auto loaded_array = array(
|
||||
shape_vec,
|
||||
float32,
|
||||
@ -259,4 +321,82 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
return load_safetensor(std::make_shared<io::FileReader>(file), s);
|
||||
}
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
void save_safetensor(
|
||||
std::shared_ptr<io::Writer> out_stream,
|
||||
std::unordered_map<std::string, array> a) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check array map
|
||||
|
||||
io::JSONNode metadata(io::JSONNode::Type::OBJECT);
|
||||
size_t offset = 0;
|
||||
for (auto& [key, arr] : a) {
|
||||
arr.eval(false);
|
||||
if (arr.nbytes() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensor] cannot serialize an empty array key: " + key);
|
||||
}
|
||||
|
||||
if (!arr.flags().contiguous) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensor] cannot serialize a non-contiguous array key: " +
|
||||
key);
|
||||
}
|
||||
auto obj = new io::JSONNode(io::JSONNode::Type::OBJECT);
|
||||
// TODO: dont make a new string
|
||||
obj->getObject()->insert(
|
||||
{"dtype",
|
||||
new io::JSONNode(
|
||||
new std::string(dtype_to_safetensor_str(arr.dtype())))});
|
||||
obj->getObject()->insert(
|
||||
{"shape", new io::JSONNode(io::JSONNode::Type::LIST)});
|
||||
for (auto& dim : arr.shape()) {
|
||||
obj->getObject()->at("shape")->getList()->push_back(
|
||||
new io::JSONNode(dim));
|
||||
}
|
||||
obj->getObject()->insert(
|
||||
{"data_offsets", new io::JSONNode(io::JSONNode::Type::LIST)});
|
||||
obj->getObject()
|
||||
->at("data_offsets")
|
||||
->getList()
|
||||
->push_back(new io::JSONNode(offset));
|
||||
obj->getObject()
|
||||
->at("data_offsets")
|
||||
->getList()
|
||||
->push_back(new io::JSONNode(offset + arr.nbytes()));
|
||||
metadata.getObject()->insert({key, obj});
|
||||
offset += arr.nbytes();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Check file
|
||||
if (!out_stream->good() || !out_stream->is_open()) {
|
||||
throw std::runtime_error(
|
||||
"[save_safetensor] Failed to open " + out_stream->label());
|
||||
}
|
||||
|
||||
auto header = io::jsonSerialize(&metadata);
|
||||
uint64_t header_len = header.length();
|
||||
out_stream->write(reinterpret_cast<char*>(&header_len), 8);
|
||||
out_stream->write(header.c_str(), header_len);
|
||||
for (auto& [key, arr] : a) {
|
||||
out_stream->write(arr.data<char>(), arr.nbytes());
|
||||
}
|
||||
}
|
||||
|
||||
void save_safetensor(
|
||||
const std::string& file_,
|
||||
std::unordered_map<std::string, array> a) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
// Add .npy to file name if it is not there
|
||||
if (file.length() < 12 ||
|
||||
file.substr(file.length() - 12, 12) != ".safetensors")
|
||||
file += ".safetensors";
|
||||
|
||||
// Serialize array
|
||||
save_safetensor(std::make_shared<io::FileWriter>(file), a);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -74,7 +74,7 @@ class JSONNode {
|
||||
return *this->_values.s;
|
||||
}
|
||||
|
||||
float getNumber() {
|
||||
uint32_t getNumber() {
|
||||
if (!is_type(Type::NUMBER)) {
|
||||
throw new std::runtime_error("not a number");
|
||||
}
|
||||
@ -94,12 +94,13 @@ class JSONNode {
|
||||
JSONObject* object;
|
||||
JSONList* list;
|
||||
std::string* s;
|
||||
float f;
|
||||
uint32_t f;
|
||||
} _values;
|
||||
Type _type;
|
||||
};
|
||||
|
||||
JSONNode parseJson(const char* data, size_t len);
|
||||
JSONNode jsonDeserialize(const char* data, size_t len);
|
||||
std::string jsonSerialize(JSONNode* node);
|
||||
|
||||
enum class TOKEN {
|
||||
CURLY_OPEN,
|
||||
|
@ -38,20 +38,42 @@ TEST_CASE("test tokenizer") {
|
||||
CHECK_EQ(tokenizer.getToken().type, io::TOKEN::NULL_TYPE);
|
||||
}
|
||||
|
||||
TEST_CASE("test parseJson") {
|
||||
TEST_CASE("test jsonSerialize") {
|
||||
auto test = new io::JSONNode(io::JSONNode::Type::OBJECT);
|
||||
auto src = io::jsonSerialize(test);
|
||||
CHECK_EQ(src, "{}");
|
||||
test = new io::JSONNode(io::JSONNode::Type::LIST);
|
||||
src = io::jsonSerialize(test);
|
||||
CHECK_EQ(src, "[]");
|
||||
test = new io::JSONNode(io::JSONNode::Type::OBJECT);
|
||||
test->getObject()->insert(
|
||||
{"test", new io::JSONNode(new std::string("testing"))});
|
||||
src = io::jsonSerialize(test);
|
||||
CHECK_EQ(src, "{\"test\":\"testing\"}");
|
||||
test = new io::JSONNode(io::JSONNode::Type::OBJECT);
|
||||
auto arr = new io::JSONNode(io::JSONNode::Type::LIST);
|
||||
arr->getList()->push_back(new io::JSONNode(1));
|
||||
arr->getList()->push_back(new io::JSONNode(2));
|
||||
test->getObject()->insert({"test", arr});
|
||||
src = io::jsonSerialize(test);
|
||||
CHECK_EQ(src, "{\"test\":[1,2]}");
|
||||
}
|
||||
|
||||
TEST_CASE("test jsonDeserialize") {
|
||||
auto raw = std::string("{}");
|
||||
auto res = io::parseJson(raw.c_str(), raw.size());
|
||||
auto res = io::jsonDeserialize(raw.c_str(), raw.size());
|
||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||
|
||||
raw = std::string("[]");
|
||||
res = io::parseJson(raw.c_str(), raw.size());
|
||||
res = io::jsonDeserialize(raw.c_str(), raw.size());
|
||||
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
||||
|
||||
raw = std::string("[");
|
||||
CHECK_THROWS_AS(io::parseJson(raw.c_str(), raw.size()), std::runtime_error);
|
||||
CHECK_THROWS_AS(
|
||||
io::jsonDeserialize(raw.c_str(), raw.size()), std::runtime_error);
|
||||
|
||||
raw = std::string("[{}, \"test\"]");
|
||||
res = io::parseJson(raw.c_str(), raw.size());
|
||||
res = io::jsonDeserialize(raw.c_str(), raw.size());
|
||||
CHECK(res.is_type(io::JSONNode::Type::LIST));
|
||||
CHECK_EQ(res.getList()->size(), 2);
|
||||
CHECK(res.getList()->at(0)->is_type(io::JSONNode::Type::OBJECT));
|
||||
@ -60,7 +82,7 @@ TEST_CASE("test parseJson") {
|
||||
|
||||
raw = std::string(
|
||||
"{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
|
||||
res = io::parseJson(raw.c_str(), raw.size());
|
||||
res = io::jsonDeserialize(raw.c_str(), raw.size());
|
||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||
CHECK_EQ(res.getObject()->size(), 1);
|
||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
|
||||
@ -138,8 +160,16 @@ TEST_CASE("test parseJson") {
|
||||
16);
|
||||
}
|
||||
|
||||
TEST_CASE("test save_safetensor") {
|
||||
auto map = std::unordered_map<std::string, array>();
|
||||
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
||||
map.insert({"test2", ones({2, 2})});
|
||||
MESSAGE("SAVING");
|
||||
save_safetensor("../../temp1", map);
|
||||
}
|
||||
|
||||
TEST_CASE("test load_safetensor") {
|
||||
auto safeDict = load_safetensor("../../temp.safe");
|
||||
auto safeDict = load_safetensor("../../temp1.safetensors");
|
||||
CHECK_EQ(safeDict.size(), 2);
|
||||
CHECK_EQ(safeDict.count("test"), 1);
|
||||
CHECK_EQ(safeDict.count("test2"), 1);
|
||||
|
Loading…
Reference in New Issue
Block a user