mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
switch to unordered map
This commit is contained in:
parent
fef579cec1
commit
9be3ea69ee
@ -1058,10 +1058,10 @@ array dequantize(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Load array from .safetensor file format */
|
||||
std::map<std::string, array> load_safetensor(
|
||||
std::unordered_map<std::string, array> load_safetensor(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s = {});
|
||||
std::map<std::string, array> load_safetensor(
|
||||
std::unordered_map<std::string, array> load_safetensor(
|
||||
const std::string& file,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
|
@ -165,7 +165,7 @@ JSONNode parseJson(const char* data, size_t len) {
|
||||
} // namespace io
|
||||
|
||||
/** Load array from reader in safetensor format */
|
||||
std::map<std::string, array> load_safetensor(
|
||||
std::unordered_map<std::string, array> load_safetensor(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s) {
|
||||
////////////////////////////////////////////////////////
|
||||
@ -190,27 +190,27 @@ std::map<std::string, array> load_safetensor(
|
||||
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
||||
}
|
||||
// Parse the json raw data
|
||||
std::map<std::string, array> res;
|
||||
for (const auto& key : *metadata.getObject()) {
|
||||
std::string dtype = key.second->getObject()->at("dtype")->getString();
|
||||
auto shape = key.second->getObject()->at("shape")->getList();
|
||||
std::unordered_map<std::string, array> res;
|
||||
for (auto& [key, obj] : *metadata.getObject()) {
|
||||
std::string dtype = obj->getObject()->at("dtype")->getString();
|
||||
auto shape = obj->getObject()->at("shape")->getList();
|
||||
std::vector<int> shape_vec;
|
||||
for (const auto& dim : *shape) {
|
||||
shape_vec.push_back(dim->getNumber());
|
||||
}
|
||||
auto data_offsets = key.second->getObject()->at("data_offsets")->getList();
|
||||
auto data_offsets = obj->getObject()->at("data_offsets")->getList();
|
||||
std::vector<int64_t> data_offsets_vec;
|
||||
for (const auto& offset : *data_offsets) {
|
||||
data_offsets_vec.push_back(offset->getNumber());
|
||||
}
|
||||
if (dtype == "F32") {
|
||||
res.insert({key.first, zeros(shape_vec, s)});
|
||||
res.insert({key, zeros(shape_vec, s)});
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::map<std::string, array> load_safetensor(
|
||||
std::unordered_map<std::string, array> load_safetensor(
|
||||
const std::string& file,
|
||||
StreamOrDevice s) {
|
||||
return load_safetensor(std::make_shared<io::FileReader>(file), s);
|
||||
|
@ -2,8 +2,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/load.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
@ -13,7 +11,7 @@ namespace mlx::core {
|
||||
|
||||
namespace io {
|
||||
class JSONNode;
|
||||
using JSONObject = std::map<std::string, JSONNode*>;
|
||||
using JSONObject = std::unordered_map<std::string, JSONNode*>;
|
||||
using JSONList = std::vector<JSONNode*>;
|
||||
|
||||
class JSONNode {
|
||||
|
@ -58,36 +58,6 @@ TEST_CASE("test parseJson") {
|
||||
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
|
||||
CHECK_EQ(res.getList()->at(1)->getString(), "test");
|
||||
|
||||
raw = std::string("{\"test\": \"test\", \"test_num\": 1}");
|
||||
res = io::parseJson(raw.c_str(), raw.size());
|
||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||
CHECK_EQ(res.getObject()->size(), 2);
|
||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::STRING));
|
||||
CHECK_EQ(res.getObject()->at("test")->getString(), "test");
|
||||
CHECK(res.getObject()->at("test_num")->is_type(io::JSONNode::Type::NUMBER));
|
||||
CHECK_EQ(res.getObject()->at("test_num")->getNumber(), 1);
|
||||
|
||||
raw = std::string("{\"test\": { \"test\": \"test\"}}");
|
||||
res = io::parseJson(raw.c_str(), raw.size());
|
||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||
CHECK_EQ(res.getObject()->size(), 1);
|
||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
|
||||
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1);
|
||||
CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type(
|
||||
io::JSONNode::Type::STRING));
|
||||
|
||||
raw = std::string("{\"test\":[1, 2]}");
|
||||
res = io::parseJson(raw.c_str(), raw.size());
|
||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
||||
CHECK_EQ(res.getObject()->size(), 1);
|
||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::LIST));
|
||||
CHECK_EQ(res.getObject()->at("test")->getList()->size(), 2);
|
||||
CHECK(res.getObject()->at("test")->getList()->at(0)->is_type(
|
||||
io::JSONNode::Type::NUMBER));
|
||||
CHECK_EQ(res.getObject()->at("test")->getList()->at(0)->getNumber(), 1);
|
||||
CHECK(res.getObject()->at("test")->getList()->at(1)->is_type(
|
||||
io::JSONNode::Type::NUMBER));
|
||||
CHECK_EQ(res.getObject()->at("test")->getList()->at(1)->getNumber(), 2);
|
||||
raw = std::string(
|
||||
"{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
|
||||
res = io::parseJson(raw.c_str(), raw.size());
|
||||
|
Loading…
Reference in New Issue
Block a user