mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 20:07:59 +08:00
Remove unnecessary string copies (#891)
1. Use std::string_view instead of std::string when no copy is needed. 2. Otherwise, move the string when possible.
This commit is contained in:
parent
45f636e759
commit
46caf0bef0
@ -190,9 +190,9 @@ std::string dtype_to_array_protocol(const Dtype& t) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Dtype from array protocol type string
|
// Dtype from array protocol type string
|
||||||
Dtype dtype_from_array_protocol(const std::string& t) {
|
Dtype dtype_from_array_protocol(std::string_view t) {
|
||||||
if (t.length() == 2 || t.length() == 3) {
|
if (t.length() == 2 || t.length() == 3) {
|
||||||
std::string r = t.length() == 3 ? t.substr(1, 2) : t;
|
std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
|
||||||
|
|
||||||
if (r == "V2") {
|
if (r == "V2") {
|
||||||
return bfloat16;
|
return bfloat16;
|
||||||
@ -238,7 +238,7 @@ Dtype dtype_from_array_protocol(const std::string& t) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[from_str] Invalid array protocol type-string: " + t);
|
"[from_str] Invalid array protocol type-string: " + std::string(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -106,6 +106,6 @@ struct TypeToDtype {
|
|||||||
// Array protocol typestring for Dtype
|
// Array protocol typestring for Dtype
|
||||||
std::string dtype_to_array_protocol(const Dtype& t);
|
std::string dtype_to_array_protocol(const Dtype& t);
|
||||||
// Dtype from array protocol type string
|
// Dtype from array protocol type string
|
||||||
Dtype dtype_from_array_protocol(const std::string& t);
|
Dtype dtype_from_array_protocol(std::string_view t);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -23,8 +23,7 @@ const std::string& NodeNamer::get_name(const array& x) {
|
|||||||
letters.push_back('A' + (var_num - 1) % 26);
|
letters.push_back('A' + (var_num - 1) % 26);
|
||||||
var_num = (var_num - 1) / 26;
|
var_num = (var_num - 1) / 26;
|
||||||
}
|
}
|
||||||
std::string name(letters.rbegin(), letters.rend());
|
names.emplace(x.id(), std::string(letters.rbegin(), letters.rend()));
|
||||||
names.insert({x.id(), name});
|
|
||||||
|
|
||||||
return get_name(x);
|
return get_name(x);
|
||||||
}
|
}
|
||||||
|
8
mlx/io.h
8
mlx/io.h
@ -23,13 +23,13 @@ using SafetensorsLoad = std::pair<
|
|||||||
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
||||||
|
|
||||||
/** Save array to file in .npy format */
|
/** Save array to file in .npy format */
|
||||||
void save(const std::string& file, array a);
|
void save(std::string file, array a);
|
||||||
|
|
||||||
/** Load array from reader in .npy format */
|
/** Load array from reader in .npy format */
|
||||||
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Load array from file in .npy format */
|
/** Load array from file in .npy format */
|
||||||
array load(const std::string& file, StreamOrDevice s = {});
|
array load(std::string file, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Load array map from .safetensors file format */
|
/** Load array map from .safetensors file format */
|
||||||
SafetensorsLoad load_safetensors(
|
SafetensorsLoad load_safetensors(
|
||||||
@ -44,13 +44,13 @@ void save_safetensors(
|
|||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, array>,
|
||||||
std::unordered_map<std::string, std::string> metadata = {});
|
std::unordered_map<std::string, std::string> metadata = {});
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
const std::string& file,
|
std::string file,
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, array>,
|
||||||
std::unordered_map<std::string, std::string> metadata = {});
|
std::unordered_map<std::string, std::string> metadata = {});
|
||||||
|
|
||||||
/** Load array map and metadata from .gguf file format */
|
/** Load array map and metadata from .gguf file format */
|
||||||
|
|
||||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
|
GGUFLoad load_gguf(std::string_view file, StreamOrDevice s = {});
|
||||||
|
|
||||||
void save_gguf(
|
void save_gguf(
|
||||||
std::string file,
|
std::string file,
|
||||||
|
@ -230,8 +230,8 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
|||||||
return array_map;
|
return array_map;
|
||||||
}
|
}
|
||||||
|
|
||||||
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
|
GGUFLoad load_gguf(std::string_view file, StreamOrDevice s) {
|
||||||
gguf_ctx* ctx = gguf_open(file.c_str());
|
gguf_ctx* ctx = gguf_open(file.data());
|
||||||
if (!ctx) {
|
if (!ctx) {
|
||||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||||
}
|
}
|
||||||
|
@ -105,7 +105,7 @@ void gguf_load_quantized(
|
|||||||
weights_per_byte = 1;
|
weights_per_byte = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string name = std::string(tensor.name, tensor.namelen);
|
std::string name(tensor.name, tensor.namelen);
|
||||||
std::vector<int> shape = get_shape(tensor);
|
std::vector<int> shape = get_shape(tensor);
|
||||||
const uint64_t weights_per_block = 32;
|
const uint64_t weights_per_block = 32;
|
||||||
if (shape[shape.size() - 1] % weights_per_block != 0) {
|
if (shape[shape.size() - 1] % weights_per_block != 0) {
|
||||||
@ -136,9 +136,9 @@ void gguf_load_quantized(
|
|||||||
extract_q8_0_data(tensor, weights, scales, biases);
|
extract_q8_0_data(tensor, weights, scales, biases);
|
||||||
}
|
}
|
||||||
|
|
||||||
a.insert({name, weights});
|
a.emplace(std::move(name), std::move(weights));
|
||||||
|
|
||||||
auto check_insert = [](auto inserted) {
|
auto check_insert = [](const auto& inserted) {
|
||||||
if (!inserted.second) {
|
if (!inserted.second) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
|
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
|
||||||
@ -147,11 +147,11 @@ void gguf_load_quantized(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::string weight_suffix = ".weight";
|
constexpr std::string_view weight_suffix = ".weight";
|
||||||
const std::string name_prefix =
|
const std::string name_prefix =
|
||||||
name.substr(0, name.length() - weight_suffix.length());
|
name.substr(0, name.length() - weight_suffix.length());
|
||||||
check_insert(a.insert({name_prefix + ".scales", scales}));
|
check_insert(a.emplace(name_prefix + ".scales", std::move(scales)));
|
||||||
check_insert(a.insert({name_prefix + ".biases", biases}));
|
check_insert(a.emplace(name_prefix + ".biases", std::move(biases)));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -114,16 +114,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Save array to file in .npy format */
|
/** Save array to file in .npy format */
|
||||||
void save(const std::string& file_, array a) {
|
void save(std::string file, array a) {
|
||||||
// Open and check file
|
|
||||||
std::string file = file_;
|
|
||||||
|
|
||||||
// Add .npy to file name if it is not there
|
// Add .npy to file name if it is not there
|
||||||
if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
|
if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
|
||||||
file += ".npy";
|
file += ".npy";
|
||||||
|
|
||||||
// Serialize array
|
// Serialize array
|
||||||
save(std::make_shared<io::FileWriter>(file), a);
|
save(std::make_shared<io::FileWriter>(std::move(file)), a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Load array from reader in .npy format */
|
/** Load array from reader in .npy format */
|
||||||
@ -227,8 +224,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Load array from file in .npy format */
|
/** Load array from file in .npy format */
|
||||||
array load(const std::string& file, StreamOrDevice s) {
|
array load(std::string file, StreamOrDevice s) {
|
||||||
return load(std::make_shared<io::FileReader>(file), s);
|
return load(std::make_shared<io::FileReader>(std::move(file)), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -60,7 +60,7 @@ std::string dtype_to_safetensor_str(Dtype t) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Dtype dtype_from_safetensor_str(std::string str) {
|
Dtype dtype_from_safetensor_str(std::string_view str) {
|
||||||
if (str == ST_F32) {
|
if (str == ST_F32) {
|
||||||
return float32;
|
return float32;
|
||||||
} else if (str == ST_F16) {
|
} else if (str == ST_F16) {
|
||||||
@ -88,7 +88,8 @@ Dtype dtype_from_safetensor_str(std::string str) {
|
|||||||
} else if (str == ST_C64) {
|
} else if (str == ST_C64) {
|
||||||
return complex64;
|
return complex64;
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("[safetensor] unsupported dtype " + str);
|
throw std::runtime_error(
|
||||||
|
"[safetensor] unsupported dtype " + std::string(str));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,9 +130,9 @@ SafetensorsLoad load_safetensors(
|
|||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::string dtype = item.value().at("dtype");
|
const std::string& dtype = item.value().at("dtype");
|
||||||
std::vector<int> shape = item.value().at("shape");
|
const std::vector<int>& shape = item.value().at("shape");
|
||||||
std::vector<size_t> data_offsets = item.value().at("data_offsets");
|
const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
|
||||||
Dtype type = dtype_from_safetensor_str(dtype);
|
Dtype type = dtype_from_safetensor_str(dtype);
|
||||||
auto loaded_array = array(
|
auto loaded_array = array(
|
||||||
shape,
|
shape,
|
||||||
@ -207,19 +208,17 @@ void save_safetensors(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void save_safetensors(
|
void save_safetensors(
|
||||||
const std::string& file_,
|
std::string file,
|
||||||
std::unordered_map<std::string, array> a,
|
std::unordered_map<std::string, array> a,
|
||||||
std::unordered_map<std::string, std::string> metadata /* = {} */) {
|
std::unordered_map<std::string, std::string> metadata /* = {} */) {
|
||||||
// Open and check file
|
|
||||||
std::string file = file_;
|
|
||||||
|
|
||||||
// Add .safetensors to file name if it is not there
|
// Add .safetensors to file name if it is not there
|
||||||
if (file.length() < 12 ||
|
if (file.length() < 12 ||
|
||||||
file.substr(file.length() - 12, 12) != ".safetensors")
|
file.substr(file.length() - 12, 12) != ".safetensors")
|
||||||
file += ".safetensors";
|
file += ".safetensors";
|
||||||
|
|
||||||
// Serialize array
|
// Serialize array
|
||||||
save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
|
save_safetensors(
|
||||||
|
std::make_shared<io::FileWriter>(std::move(file)), a, metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -96,7 +96,7 @@ inline array matrix_norm(
|
|||||||
|
|
||||||
inline array matrix_norm(
|
inline array matrix_norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::string& ord,
|
std::string_view ord,
|
||||||
const std::vector<int>& axis,
|
const std::vector<int>& axis,
|
||||||
bool keepdims,
|
bool keepdims,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
@ -153,7 +153,7 @@ array norm(
|
|||||||
|
|
||||||
array norm(
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::string& ord,
|
std::string_view ord,
|
||||||
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||||
bool keepdims /* = false */,
|
bool keepdims /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
|
@ -38,13 +38,13 @@ inline array norm(
|
|||||||
}
|
}
|
||||||
array norm(
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::string& ord,
|
std::string_view ord,
|
||||||
const std::optional<std::vector<int>>& axis = std::nullopt,
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||||
bool keepdims = false,
|
bool keepdims = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
inline array norm(
|
inline array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::string& ord,
|
std::string_view ord,
|
||||||
int axis,
|
int axis,
|
||||||
bool keepdims = false,
|
bool keepdims = false,
|
||||||
StreamOrDevice s = {}) {
|
StreamOrDevice s = {}) {
|
||||||
|
@ -60,10 +60,10 @@ struct buffer_info {
|
|||||||
std::vector<ssize_t> strides;
|
std::vector<ssize_t> strides;
|
||||||
|
|
||||||
buffer_info(
|
buffer_info(
|
||||||
const std::string& format,
|
std::string format,
|
||||||
std::vector<ssize_t> shape_in,
|
std::vector<ssize_t> shape_in,
|
||||||
std::vector<ssize_t> strides_in)
|
std::vector<ssize_t> strides_in)
|
||||||
: format(format),
|
: format(std::move(format)),
|
||||||
shape(std::move(shape_in)),
|
shape(std::move(shape_in)),
|
||||||
strides(std::move(strides_in)) {}
|
strides(std::move(strides_in)) {}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user