Remove unnecessary string copies (#891)

1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
This commit is contained in:
Cheng 2024-03-29 05:14:59 +09:00 committed by GitHub
parent 45f636e759
commit 46caf0bef0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 36 additions and 41 deletions

View File

@ -190,9 +190,9 @@ std::string dtype_to_array_protocol(const Dtype& t) {
}
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t) {
Dtype dtype_from_array_protocol(std::string_view t) {
if (t.length() == 2 || t.length() == 3) {
std::string r = t.length() == 3 ? t.substr(1, 2) : t;
std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
if (r == "V2") {
return bfloat16;
@ -238,7 +238,7 @@ Dtype dtype_from_array_protocol(const std::string& t) {
}
throw std::invalid_argument(
"[from_str] Invalid array protocol type-string: " + t);
"[from_str] Invalid array protocol type-string: " + std::string(t));
}
} // namespace mlx::core

View File

@ -106,6 +106,6 @@ struct TypeToDtype {
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t);
Dtype dtype_from_array_protocol(std::string_view t);
} // namespace mlx::core

View File

@ -23,8 +23,7 @@ const std::string& NodeNamer::get_name(const array& x) {
letters.push_back('A' + (var_num - 1) % 26);
var_num = (var_num - 1) / 26;
}
std::string name(letters.rbegin(), letters.rend());
names.insert({x.id(), name});
names.emplace(x.id(), std::string(letters.rbegin(), letters.rend()));
return get_name(x);
}

View File

@ -23,13 +23,13 @@ using SafetensorsLoad = std::pair<
void save(std::shared_ptr<io::Writer> out_stream, array a);
/** Save array to file in .npy format */
void save(const std::string& file, array a);
void save(std::string file, array a);
/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s = {});
array load(std::string file, StreamOrDevice s = {});
/** Load array map from .safetensors file format */
SafetensorsLoad load_safetensors(
@ -44,13 +44,13 @@ void save_safetensors(
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});
void save_safetensors(
const std::string& file,
std::string file,
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});
/** Load array map and metadata from .gguf file format */
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
GGUFLoad load_gguf(std::string_view file, StreamOrDevice s = {});
void save_gguf(
std::string file,

View File

@ -230,8 +230,8 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
return array_map;
}
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
gguf_ctx* ctx = gguf_open(file.c_str());
GGUFLoad load_gguf(std::string_view file, StreamOrDevice s) {
gguf_ctx* ctx = gguf_open(file.data());
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");
}

View File

@ -105,7 +105,7 @@ void gguf_load_quantized(
weights_per_byte = 1;
}
std::string name = std::string(tensor.name, tensor.namelen);
std::string name(tensor.name, tensor.namelen);
std::vector<int> shape = get_shape(tensor);
const uint64_t weights_per_block = 32;
if (shape[shape.size() - 1] % weights_per_block != 0) {
@ -136,9 +136,9 @@ void gguf_load_quantized(
extract_q8_0_data(tensor, weights, scales, biases);
}
a.insert({name, weights});
a.emplace(std::move(name), std::move(weights));
auto check_insert = [](auto inserted) {
auto check_insert = [](const auto& inserted) {
if (!inserted.second) {
std::ostringstream msg;
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
@ -147,11 +147,11 @@ void gguf_load_quantized(
}
};
const std::string weight_suffix = ".weight";
constexpr std::string_view weight_suffix = ".weight";
const std::string name_prefix =
name.substr(0, name.length() - weight_suffix.length());
check_insert(a.insert({name_prefix + ".scales", scales}));
check_insert(a.insert({name_prefix + ".biases", biases}));
check_insert(a.emplace(name_prefix + ".scales", std::move(scales)));
check_insert(a.emplace(name_prefix + ".biases", std::move(biases)));
}
} // namespace mlx::core

View File

@ -114,16 +114,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
}
/** Save array to file in .npy format */
void save(const std::string& file_, array a) {
// Open and check file
std::string file = file_;
void save(std::string file, array a) {
// Add .npy to file name if it is not there
if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
file += ".npy";
// Serialize array
save(std::make_shared<io::FileWriter>(file), a);
save(std::make_shared<io::FileWriter>(std::move(file)), a);
}
/** Load array from reader in .npy format */
@ -227,8 +224,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
}
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s) {
return load(std::make_shared<io::FileReader>(file), s);
array load(std::string file, StreamOrDevice s) {
return load(std::make_shared<io::FileReader>(std::move(file)), s);
}
} // namespace mlx::core

View File

@ -60,7 +60,7 @@ std::string dtype_to_safetensor_str(Dtype t) {
}
}
Dtype dtype_from_safetensor_str(std::string str) {
Dtype dtype_from_safetensor_str(std::string_view str) {
if (str == ST_F32) {
return float32;
} else if (str == ST_F16) {
@ -88,7 +88,8 @@ Dtype dtype_from_safetensor_str(std::string str) {
} else if (str == ST_C64) {
return complex64;
} else {
throw std::runtime_error("[safetensor] unsupported dtype " + str);
throw std::runtime_error(
"[safetensor] unsupported dtype " + std::string(str));
}
}
@ -129,9 +130,9 @@ SafetensorsLoad load_safetensors(
}
continue;
}
std::string dtype = item.value().at("dtype");
std::vector<int> shape = item.value().at("shape");
std::vector<size_t> data_offsets = item.value().at("data_offsets");
const std::string& dtype = item.value().at("dtype");
const std::vector<int>& shape = item.value().at("shape");
const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
Dtype type = dtype_from_safetensor_str(dtype);
auto loaded_array = array(
shape,
@ -207,19 +208,17 @@ void save_safetensors(
}
void save_safetensors(
const std::string& file_,
std::string file,
std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
// Open and check file
std::string file = file_;
// Add .safetensors to file name if it is not there
if (file.length() < 12 ||
file.substr(file.length() - 12, 12) != ".safetensors")
file += ".safetensors";
// Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
save_safetensors(
std::make_shared<io::FileWriter>(std::move(file)), a, metadata);
}
} // namespace mlx::core

View File

@ -96,7 +96,7 @@ inline array matrix_norm(
inline array matrix_norm(
const array& a,
const std::string& ord,
std::string_view ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
@ -153,7 +153,7 @@ array norm(
array norm(
const array& a,
const std::string& ord,
std::string_view ord,
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {

View File

@ -38,13 +38,13 @@ inline array norm(
}
array norm(
const array& a,
const std::string& ord,
std::string_view ord,
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array norm(
const array& a,
const std::string& ord,
std::string_view ord,
int axis,
bool keepdims = false,
StreamOrDevice s = {}) {

View File

@ -60,10 +60,10 @@ struct buffer_info {
std::vector<ssize_t> strides;
buffer_info(
const std::string& format,
std::string format,
std::vector<ssize_t> shape_in,
std::vector<ssize_t> strides_in)
: format(format),
: format(std::move(format)),
shape(std::move(shape_in)),
strides(std::move(strides_in)) {}