mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
switch to unordered map
This commit is contained in:
parent
fef579cec1
commit
9be3ea69ee
@ -1058,10 +1058,10 @@ array dequantize(
|
|||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Load array from .safetensor file format */
|
/** Load array from .safetensor file format */
|
||||||
std::map<std::string, array> load_safetensor(
|
std::unordered_map<std::string, array> load_safetensor(
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
std::map<std::string, array> load_safetensor(
|
std::unordered_map<std::string, array> load_safetensor(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -165,7 +165,7 @@ JSONNode parseJson(const char* data, size_t len) {
|
|||||||
} // namespace io
|
} // namespace io
|
||||||
|
|
||||||
/** Load array from reader in safetensor format */
|
/** Load array from reader in safetensor format */
|
||||||
std::map<std::string, array> load_safetensor(
|
std::unordered_map<std::string, array> load_safetensor(
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
@ -190,27 +190,27 @@ std::map<std::string, array> load_safetensor(
|
|||||||
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
||||||
}
|
}
|
||||||
// Parse the json raw data
|
// Parse the json raw data
|
||||||
std::map<std::string, array> res;
|
std::unordered_map<std::string, array> res;
|
||||||
for (const auto& key : *metadata.getObject()) {
|
for (auto& [key, obj] : *metadata.getObject()) {
|
||||||
std::string dtype = key.second->getObject()->at("dtype")->getString();
|
std::string dtype = obj->getObject()->at("dtype")->getString();
|
||||||
auto shape = key.second->getObject()->at("shape")->getList();
|
auto shape = obj->getObject()->at("shape")->getList();
|
||||||
std::vector<int> shape_vec;
|
std::vector<int> shape_vec;
|
||||||
for (const auto& dim : *shape) {
|
for (const auto& dim : *shape) {
|
||||||
shape_vec.push_back(dim->getNumber());
|
shape_vec.push_back(dim->getNumber());
|
||||||
}
|
}
|
||||||
auto data_offsets = key.second->getObject()->at("data_offsets")->getList();
|
auto data_offsets = obj->getObject()->at("data_offsets")->getList();
|
||||||
std::vector<int64_t> data_offsets_vec;
|
std::vector<int64_t> data_offsets_vec;
|
||||||
for (const auto& offset : *data_offsets) {
|
for (const auto& offset : *data_offsets) {
|
||||||
data_offsets_vec.push_back(offset->getNumber());
|
data_offsets_vec.push_back(offset->getNumber());
|
||||||
}
|
}
|
||||||
if (dtype == "F32") {
|
if (dtype == "F32") {
|
||||||
res.insert({key.first, zeros(shape_vec, s)});
|
res.insert({key, zeros(shape_vec, s)});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<std::string, array> load_safetensor(
|
std::unordered_map<std::string, array> load_safetensor(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
return load_safetensor(std::make_shared<io::FileReader>(file), s);
|
return load_safetensor(std::make_shared<io::FileReader>(file), s);
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#include "mlx/load.h"
|
#include "mlx/load.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@ -13,7 +11,7 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace io {
|
namespace io {
|
||||||
class JSONNode;
|
class JSONNode;
|
||||||
using JSONObject = std::map<std::string, JSONNode*>;
|
using JSONObject = std::unordered_map<std::string, JSONNode*>;
|
||||||
using JSONList = std::vector<JSONNode*>;
|
using JSONList = std::vector<JSONNode*>;
|
||||||
|
|
||||||
class JSONNode {
|
class JSONNode {
|
||||||
|
@ -58,36 +58,6 @@ TEST_CASE("test parseJson") {
|
|||||||
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
|
CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING));
|
||||||
CHECK_EQ(res.getList()->at(1)->getString(), "test");
|
CHECK_EQ(res.getList()->at(1)->getString(), "test");
|
||||||
|
|
||||||
raw = std::string("{\"test\": \"test\", \"test_num\": 1}");
|
|
||||||
res = io::parseJson(raw.c_str(), raw.size());
|
|
||||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
|
||||||
CHECK_EQ(res.getObject()->size(), 2);
|
|
||||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::STRING));
|
|
||||||
CHECK_EQ(res.getObject()->at("test")->getString(), "test");
|
|
||||||
CHECK(res.getObject()->at("test_num")->is_type(io::JSONNode::Type::NUMBER));
|
|
||||||
CHECK_EQ(res.getObject()->at("test_num")->getNumber(), 1);
|
|
||||||
|
|
||||||
raw = std::string("{\"test\": { \"test\": \"test\"}}");
|
|
||||||
res = io::parseJson(raw.c_str(), raw.size());
|
|
||||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
|
||||||
CHECK_EQ(res.getObject()->size(), 1);
|
|
||||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT));
|
|
||||||
CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1);
|
|
||||||
CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type(
|
|
||||||
io::JSONNode::Type::STRING));
|
|
||||||
|
|
||||||
raw = std::string("{\"test\":[1, 2]}");
|
|
||||||
res = io::parseJson(raw.c_str(), raw.size());
|
|
||||||
CHECK(res.is_type(io::JSONNode::Type::OBJECT));
|
|
||||||
CHECK_EQ(res.getObject()->size(), 1);
|
|
||||||
CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::LIST));
|
|
||||||
CHECK_EQ(res.getObject()->at("test")->getList()->size(), 2);
|
|
||||||
CHECK(res.getObject()->at("test")->getList()->at(0)->is_type(
|
|
||||||
io::JSONNode::Type::NUMBER));
|
|
||||||
CHECK_EQ(res.getObject()->at("test")->getList()->at(0)->getNumber(), 1);
|
|
||||||
CHECK(res.getObject()->at("test")->getList()->at(1)->is_type(
|
|
||||||
io::JSONNode::Type::NUMBER));
|
|
||||||
CHECK_EQ(res.getObject()->at("test")->getList()->at(1)->getNumber(), 2);
|
|
||||||
raw = std::string(
|
raw = std::string(
|
||||||
"{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
|
"{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}");
|
||||||
res = io::parseJson(raw.c_str(), raw.size());
|
res = io::parseJson(raw.c_str(), raw.size());
|
||||||
|
Loading…
Reference in New Issue
Block a user