switch to unordered map

This commit is contained in:
dc-dc-dc 2023-12-18 11:57:46 -05:00
parent fef579cec1
commit 9be3ea69ee
4 changed files with 11 additions and 43 deletions

View File

@ -1058,10 +1058,10 @@ array dequantize(
StreamOrDevice s = {});
/** Load array from .safetensor file format */
std::map<std::string, array> load_safetensor(
std::unordered_map<std::string, array> load_safetensor(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s = {});
std::map<std::string, array> load_safetensor(
std::unordered_map<std::string, array> load_safetensor(
const std::string& file,
StreamOrDevice s = {});

View File

@ -165,7 +165,7 @@ JSONNode parseJson(const char* data, size_t len) {
} // namespace io
/** Load array from reader in safetensor format */
std::map<std::string, array> load_safetensor(
std::unordered_map<std::string, array> load_safetensor(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s) {
////////////////////////////////////////////////////////
@ -190,27 +190,27 @@ std::map<std::string, array> load_safetensor(
"[load_safetensor] Invalid json metadata " + in_stream->label());
}
// Parse the json raw data
std::map<std::string, array> res;
for (const auto& key : *metadata.getObject()) {
std::string dtype = key.second->getObject()->at("dtype")->getString();
auto shape = key.second->getObject()->at("shape")->getList();
std::unordered_map<std::string, array> res;
for (auto& [key, obj] : *metadata.getObject()) {
std::string dtype = obj->getObject()->at("dtype")->getString();
auto shape = obj->getObject()->at("shape")->getList();
std::vector<int> shape_vec;
for (const auto& dim : *shape) {
shape_vec.push_back(dim->getNumber());
}
auto data_offsets = key.second->getObject()->at("data_offsets")->getList();
auto data_offsets = obj->getObject()->at("data_offsets")->getList();
std::vector<int64_t> data_offsets_vec;
for (const auto& offset : *data_offsets) {
data_offsets_vec.push_back(offset->getNumber());
}
if (dtype == "F32") {
res.insert({key.first, zeros(shape_vec, s)});
res.insert({key, zeros(shape_vec, s)});
}
}
return res;
}
std::map<std::string, array> load_safetensor(
std::unordered_map<std::string, array> load_safetensor(
const std::string& file,
StreamOrDevice s) {
return load_safetensor(std::make_shared<io::FileReader>(file), s);

View File

@ -2,8 +2,6 @@
#pragma once
#include <map>
#include "mlx/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
@ -13,7 +11,7 @@ namespace mlx::core {
namespace io {
class JSONNode;
using JSONObject = std::map<std::string, JSONNode*>;
using JSONObject = std::unordered_map<std::string, JSONNode*>;
using JSONList = std::vector<JSONNode*>;
class JSONNode {

View File

@ -58,36 +58,6 @@ TEST_CASE("test parseJson") {
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
CHECK_EQ(res.getList()->at(1)->getString(), "test");
raw = std::string("{\"test\": \"test\", \"test_num\": 1}");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->size(), 2);
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::STRING));
CHECK_EQ(res.getObject()->at("test")->getString(), "test");
CHECK(res.getObject()->at("test_num")->is_type(io::JSONNode::Type::NUMBER));
CHECK_EQ(res.getObject()->at("test_num")->getNumber(), 1);
raw = std::string("{\"test\": { \"test\": \"test\"}}");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->size(), 1);
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1);
CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type(
io::JSONNode::Type::STRING));
raw = std::string("{\"test\":[1, 2]}");
res = io::parseJson(raw.c_str(), raw.size());
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
CHECK_EQ(res.getObject()->size(), 1);
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::LIST));
CHECK_EQ(res.getObject()->at("test")->getList()->size(), 2);
CHECK(res.getObject()->at("test")->getList()->at(0)->is_type(
io::JSONNode::Type::NUMBER));
CHECK_EQ(res.getObject()->at("test")->getList()->at(0)->getNumber(), 1);
CHECK(res.getObject()->at("test")->getList()->at(1)->is_type(
io::JSONNode::Type::NUMBER));
CHECK_EQ(res.getObject()->at("test")->getList()->at(1)->getNumber(), 2);
raw = std::string(
"{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
res = io::parseJson(raw.c_str(), raw.size());