From 9be3ea69eea45d9ee607b32c60869e514e242dd2 Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Mon, 18 Dec 2023 11:57:46 -0500 Subject: [PATCH] switch to unordered map --- mlx/ops.h | 4 ++-- mlx/safetensor.cpp | 16 ++++++++-------- mlx/safetensor.h | 4 +--- tests/load_tests.cpp | 30 ------------------------------ 4 files changed, 11 insertions(+), 43 deletions(-) diff --git a/mlx/ops.h b/mlx/ops.h index fed4a4adb..bd0521cc8 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1058,10 +1058,10 @@ array dequantize( StreamOrDevice s = {}); /** Load array from .safetensor file format */ -std::map load_safetensor( +std::unordered_map load_safetensor( std::shared_ptr in_stream, StreamOrDevice s = {}); -std::map load_safetensor( +std::unordered_map load_safetensor( const std::string& file, StreamOrDevice s = {}); diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index d986da6e7..072e46177 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -165,7 +165,7 @@ JSONNode parseJson(const char* data, size_t len) { } // namespace io /** Load array from reader in safetensor format */ -std::map load_safetensor( +std::unordered_map load_safetensor( std::shared_ptr in_stream, StreamOrDevice s) { //////////////////////////////////////////////////////// @@ -190,27 +190,27 @@ std::map load_safetensor( "[load_safetensor] Invalid json metadata " + in_stream->label()); } // Parse the json raw data - std::map res; - for (const auto& key : *metadata.getObject()) { - std::string dtype = key.second->getObject()->at("dtype")->getString(); - auto shape = key.second->getObject()->at("shape")->getList(); + std::unordered_map res; + for (auto& [key, obj] : *metadata.getObject()) { + std::string dtype = obj->getObject()->at("dtype")->getString(); + auto shape = obj->getObject()->at("shape")->getList(); std::vector shape_vec; for (const auto& dim : *shape) { shape_vec.push_back(dim->getNumber()); } - auto data_offsets = key.second->getObject()->at("data_offsets")->getList(); + auto data_offsets = obj->getObject()->at("data_offsets")->getList(); std::vector data_offsets_vec; for (const auto& offset : *data_offsets) { data_offsets_vec.push_back(offset->getNumber()); } if (dtype == "F32") { - res.insert({key.first, zeros(shape_vec, s)}); + res.insert({key, zeros(shape_vec, s)}); } } return res; } -std::map load_safetensor( +std::unordered_map load_safetensor( const std::string& file, StreamOrDevice s) { return load_safetensor(std::make_shared(file), s); diff --git a/mlx/safetensor.h b/mlx/safetensor.h index 941b210c5..f6057dea5 100644 --- a/mlx/safetensor.h +++ b/mlx/safetensor.h @@ -2,8 +2,6 @@ #pragma once -#include - #include "mlx/load.h" #include "mlx/ops.h" #include "mlx/primitives.h" @@ -13,7 +11,7 @@ namespace mlx::core { namespace io { class JSONNode; -using JSONObject = std::map; +using JSONObject = std::unordered_map; using JSONList = std::vector; class JSONNode { diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 06f003c2a..198fcddd0 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -58,36 +58,6 @@ TEST_CASE("test parseJson") { CHECK(res.getList()->at(1)->is_type(io::JSONNode::Type::STRING)); CHECK_EQ(res.getList()->at(1)->getString(), "test"); - raw = std::string("{\"test\": \"test\", \"test_num\": 1}"); - res = io::parseJson(raw.c_str(), raw.size()); - CHECK(res.is_type(io::JSONNode::Type::OBJECT)); - CHECK_EQ(res.getObject()->size(), 2); - CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::STRING)); - CHECK_EQ(res.getObject()->at("test")->getString(), "test"); - CHECK(res.getObject()->at("test_num")->is_type(io::JSONNode::Type::NUMBER)); - CHECK_EQ(res.getObject()->at("test_num")->getNumber(), 1); - - raw = std::string("{\"test\": { \"test\": \"test\"}}"); - res = io::parseJson(raw.c_str(), raw.size()); - CHECK(res.is_type(io::JSONNode::Type::OBJECT)); - CHECK_EQ(res.getObject()->size(), 1); - CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::OBJECT)); - CHECK_EQ(res.getObject()->at("test")->getObject()->size(), 1); - CHECK(res.getObject()->at("test")->getObject()->at("test")->is_type( - io::JSONNode::Type::STRING)); - - raw = std::string("{\"test\":[1, 2]}"); - res = io::parseJson(raw.c_str(), raw.size()); - CHECK(res.is_type(io::JSONNode::Type::OBJECT)); - CHECK_EQ(res.getObject()->size(), 1); - CHECK(res.getObject()->at("test")->is_type(io::JSONNode::Type::LIST)); - CHECK_EQ(res.getObject()->at("test")->getList()->size(), 2); - CHECK(res.getObject()->at("test")->getList()->at(0)->is_type( - io::JSONNode::Type::NUMBER)); - CHECK_EQ(res.getObject()->at("test")->getList()->at(0)->getNumber(), 1); - CHECK(res.getObject()->at("test")->getList()->at(1)->is_type( - io::JSONNode::Type::NUMBER)); - CHECK_EQ(res.getObject()->at("test")->getList()->at(1)->getNumber(), 2); raw = std::string( "{\"test\":{\"dtype\":\"F32\",\"shape\":[4], \"data_offsets\":[0, 16]}}"); res = io::parseJson(raw.c_str(), raw.size());