Remove unnecessary string copies (#891)

1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
This commit is contained in:
Cheng 2024-03-29 05:14:59 +09:00 committed by GitHub
parent 45f636e759
commit 46caf0bef0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 36 additions and 41 deletions

View File

@ -190,9 +190,9 @@ std::string dtype_to_array_protocol(const Dtype& t) {
} }
// Dtype from array protocol type string // Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t) { Dtype dtype_from_array_protocol(std::string_view t) {
if (t.length() == 2 || t.length() == 3) { if (t.length() == 2 || t.length() == 3) {
std::string r = t.length() == 3 ? t.substr(1, 2) : t; std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
if (r == "V2") { if (r == "V2") {
return bfloat16; return bfloat16;
@ -238,7 +238,7 @@ Dtype dtype_from_array_protocol(const std::string& t) {
} }
throw std::invalid_argument( throw std::invalid_argument(
"[from_str] Invalid array protocol type-string: " + t); "[from_str] Invalid array protocol type-string: " + std::string(t));
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -106,6 +106,6 @@ struct TypeToDtype {
// Array protocol typestring for Dtype // Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t); std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string // Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t); Dtype dtype_from_array_protocol(std::string_view t);
} // namespace mlx::core } // namespace mlx::core

View File

@ -23,8 +23,7 @@ const std::string& NodeNamer::get_name(const array& x) {
letters.push_back('A' + (var_num - 1) % 26); letters.push_back('A' + (var_num - 1) % 26);
var_num = (var_num - 1) / 26; var_num = (var_num - 1) / 26;
} }
std::string name(letters.rbegin(), letters.rend()); names.emplace(x.id(), std::string(letters.rbegin(), letters.rend()));
names.insert({x.id(), name});
return get_name(x); return get_name(x);
} }

View File

@ -23,13 +23,13 @@ using SafetensorsLoad = std::pair<
void save(std::shared_ptr<io::Writer> out_stream, array a); void save(std::shared_ptr<io::Writer> out_stream, array a);
/** Save array to file in .npy format */ /** Save array to file in .npy format */
void save(const std::string& file, array a); void save(std::string file, array a);
/** Load array from reader in .npy format */ /** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {}); array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
/** Load array from file in .npy format */ /** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s = {}); array load(std::string file, StreamOrDevice s = {});
/** Load array map from .safetensors file format */ /** Load array map from .safetensors file format */
SafetensorsLoad load_safetensors( SafetensorsLoad load_safetensors(
@ -44,13 +44,13 @@ void save_safetensors(
std::unordered_map<std::string, array>, std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {}); std::unordered_map<std::string, std::string> metadata = {});
void save_safetensors( void save_safetensors(
const std::string& file, std::string file,
std::unordered_map<std::string, array>, std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {}); std::unordered_map<std::string, std::string> metadata = {});
/** Load array map and metadata from .gguf file format */ /** Load array map and metadata from .gguf file format */
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {}); GGUFLoad load_gguf(std::string_view file, StreamOrDevice s = {});
void save_gguf( void save_gguf(
std::string file, std::string file,

View File

@ -230,8 +230,8 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
return array_map; return array_map;
} }
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { GGUFLoad load_gguf(std::string_view file, StreamOrDevice s) {
gguf_ctx* ctx = gguf_open(file.c_str()); gguf_ctx* ctx = gguf_open(file.data());
if (!ctx) { if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed"); throw std::runtime_error("[load_gguf] gguf_init failed");
} }

View File

@ -105,7 +105,7 @@ void gguf_load_quantized(
weights_per_byte = 1; weights_per_byte = 1;
} }
std::string name = std::string(tensor.name, tensor.namelen); std::string name(tensor.name, tensor.namelen);
std::vector<int> shape = get_shape(tensor); std::vector<int> shape = get_shape(tensor);
const uint64_t weights_per_block = 32; const uint64_t weights_per_block = 32;
if (shape[shape.size() - 1] % weights_per_block != 0) { if (shape[shape.size() - 1] % weights_per_block != 0) {
@ -136,9 +136,9 @@ void gguf_load_quantized(
extract_q8_0_data(tensor, weights, scales, biases); extract_q8_0_data(tensor, weights, scales, biases);
} }
a.insert({name, weights}); a.emplace(std::move(name), std::move(weights));
auto check_insert = [](auto inserted) { auto check_insert = [](const auto& inserted) {
if (!inserted.second) { if (!inserted.second) {
std::ostringstream msg; std::ostringstream msg;
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
@ -147,11 +147,11 @@ void gguf_load_quantized(
} }
}; };
const std::string weight_suffix = ".weight"; constexpr std::string_view weight_suffix = ".weight";
const std::string name_prefix = const std::string name_prefix =
name.substr(0, name.length() - weight_suffix.length()); name.substr(0, name.length() - weight_suffix.length());
check_insert(a.insert({name_prefix + ".scales", scales})); check_insert(a.emplace(name_prefix + ".scales", std::move(scales)));
check_insert(a.insert({name_prefix + ".biases", biases})); check_insert(a.emplace(name_prefix + ".biases", std::move(biases)));
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -114,16 +114,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
} }
/** Save array to file in .npy format */ /** Save array to file in .npy format */
void save(const std::string& file_, array a) { void save(std::string file, array a) {
// Open and check file
std::string file = file_;
// Add .npy to file name if it is not there // Add .npy to file name if it is not there
if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy") if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
file += ".npy"; file += ".npy";
// Serialize array // Serialize array
save(std::make_shared<io::FileWriter>(file), a); save(std::make_shared<io::FileWriter>(std::move(file)), a);
} }
/** Load array from reader in .npy format */ /** Load array from reader in .npy format */
@ -227,8 +224,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
} }
/** Load array from file in .npy format */ /** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s) { array load(std::string file, StreamOrDevice s) {
return load(std::make_shared<io::FileReader>(file), s); return load(std::make_shared<io::FileReader>(std::move(file)), s);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -60,7 +60,7 @@ std::string dtype_to_safetensor_str(Dtype t) {
} }
} }
Dtype dtype_from_safetensor_str(std::string str) { Dtype dtype_from_safetensor_str(std::string_view str) {
if (str == ST_F32) { if (str == ST_F32) {
return float32; return float32;
} else if (str == ST_F16) { } else if (str == ST_F16) {
@ -88,7 +88,8 @@ Dtype dtype_from_safetensor_str(std::string str) {
} else if (str == ST_C64) { } else if (str == ST_C64) {
return complex64; return complex64;
} else { } else {
throw std::runtime_error("[safetensor] unsupported dtype " + str); throw std::runtime_error(
"[safetensor] unsupported dtype " + std::string(str));
} }
} }
@ -129,9 +130,9 @@ SafetensorsLoad load_safetensors(
} }
continue; continue;
} }
std::string dtype = item.value().at("dtype"); const std::string& dtype = item.value().at("dtype");
std::vector<int> shape = item.value().at("shape"); const std::vector<int>& shape = item.value().at("shape");
std::vector<size_t> data_offsets = item.value().at("data_offsets"); const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
Dtype type = dtype_from_safetensor_str(dtype); Dtype type = dtype_from_safetensor_str(dtype);
auto loaded_array = array( auto loaded_array = array(
shape, shape,
@ -207,19 +208,17 @@ void save_safetensors(
} }
void save_safetensors( void save_safetensors(
const std::string& file_, std::string file,
std::unordered_map<std::string, array> a, std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) { std::unordered_map<std::string, std::string> metadata /* = {} */) {
// Open and check file
std::string file = file_;
// Add .safetensors to file name if it is not there // Add .safetensors to file name if it is not there
if (file.length() < 12 || if (file.length() < 12 ||
file.substr(file.length() - 12, 12) != ".safetensors") file.substr(file.length() - 12, 12) != ".safetensors")
file += ".safetensors"; file += ".safetensors";
// Serialize array // Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata); save_safetensors(
std::make_shared<io::FileWriter>(std::move(file)), a, metadata);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -96,7 +96,7 @@ inline array matrix_norm(
inline array matrix_norm( inline array matrix_norm(
const array& a, const array& a,
const std::string& ord, std::string_view ord,
const std::vector<int>& axis, const std::vector<int>& axis,
bool keepdims, bool keepdims,
StreamOrDevice s) { StreamOrDevice s) {
@ -153,7 +153,7 @@ array norm(
array norm( array norm(
const array& a, const array& a,
const std::string& ord, std::string_view ord,
const std::optional<std::vector<int>>& axis /* = std::nullopt */, const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */, bool keepdims /* = false */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {

View File

@ -38,13 +38,13 @@ inline array norm(
} }
array norm( array norm(
const array& a, const array& a,
const std::string& ord, std::string_view ord,
const std::optional<std::vector<int>>& axis = std::nullopt, const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false, bool keepdims = false,
StreamOrDevice s = {}); StreamOrDevice s = {});
inline array norm( inline array norm(
const array& a, const array& a,
const std::string& ord, std::string_view ord,
int axis, int axis,
bool keepdims = false, bool keepdims = false,
StreamOrDevice s = {}) { StreamOrDevice s = {}) {

View File

@ -60,10 +60,10 @@ struct buffer_info {
std::vector<ssize_t> strides; std::vector<ssize_t> strides;
buffer_info( buffer_info(
const std::string& format, std::string format,
std::vector<ssize_t> shape_in, std::vector<ssize_t> shape_in,
std::vector<ssize_t> strides_in) std::vector<ssize_t> strides_in)
: format(format), : format(std::move(format)),
shape(std::move(shape_in)), shape(std::move(shape_in)),
strides(std::move(strides_in)) {} strides(std::move(strides_in)) {}