diff --git a/mlx/dtype.cpp b/mlx/dtype.cpp index 5eaf7c90e..465e51bd5 100644 --- a/mlx/dtype.cpp +++ b/mlx/dtype.cpp @@ -190,9 +190,9 @@ std::string dtype_to_array_protocol(const Dtype& t) { } // Dtype from array protocol type string -Dtype dtype_from_array_protocol(const std::string& t) { +Dtype dtype_from_array_protocol(std::string_view t) { if (t.length() == 2 || t.length() == 3) { - std::string r = t.length() == 3 ? t.substr(1, 2) : t; + std::string_view r = t.length() == 3 ? t.substr(1, 2) : t; if (r == "V2") { return bfloat16; @@ -238,7 +238,7 @@ Dtype dtype_from_array_protocol(const std::string& t) { } throw std::invalid_argument( - "[from_str] Invalid array protocol type-string: " + t); + "[from_str] Invalid array protocol type-string: " + std::string(t)); } } // namespace mlx::core diff --git a/mlx/dtype.h b/mlx/dtype.h index 410b70fb1..007b09d74 100644 --- a/mlx/dtype.h +++ b/mlx/dtype.h @@ -106,6 +106,6 @@ struct TypeToDtype { // Array protocol typestring for Dtype std::string dtype_to_array_protocol(const Dtype& t); // Dtype from array protocol type string -Dtype dtype_from_array_protocol(const std::string& t); +Dtype dtype_from_array_protocol(std::string_view t); } // namespace mlx::core diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index ba031c441..c6b594247 100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -23,8 +23,7 @@ const std::string& NodeNamer::get_name(const array& x) { letters.push_back('A' + (var_num - 1) % 26); var_num = (var_num - 1) / 26; } - std::string name(letters.rbegin(), letters.rend()); - names.insert({x.id(), name}); + names.emplace(x.id(), std::string(letters.rbegin(), letters.rend())); return get_name(x); } diff --git a/mlx/io.h b/mlx/io.h index 59866ea27..640f62f68 100644 --- a/mlx/io.h +++ b/mlx/io.h @@ -23,13 +23,13 @@ using SafetensorsLoad = std::pair< void save(std::shared_ptr out_stream, array a); /** Save array to file in .npy format */ -void save(const std::string& file, array a); +void 
save(std::string file, array a); /** Load array from reader in .npy format */ array load(std::shared_ptr in_stream, StreamOrDevice s = {}); /** Load array from file in .npy format */ -array load(const std::string& file, StreamOrDevice s = {}); +array load(std::string file, StreamOrDevice s = {}); /** Load array map from .safetensors file format */ SafetensorsLoad load_safetensors( @@ -44,13 +44,13 @@ void save_safetensors( std::unordered_map, std::unordered_map metadata = {}); void save_safetensors( - const std::string& file, + std::string file, std::unordered_map, std::unordered_map metadata = {}); /** Load array map and metadata from .gguf file format */ -GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {}); +GGUFLoad load_gguf(std::string_view file, StreamOrDevice s = {}); void save_gguf( std::string file, diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index 9e7953d6e..e23fb255e 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -230,8 +230,8 @@ std::unordered_map load_arrays(gguf_ctx* ctx) { return array_map; } -GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { - gguf_ctx* ctx = gguf_open(file.c_str()); +GGUFLoad load_gguf(std::string_view file, StreamOrDevice s) { + gguf_ctx* ctx = gguf_open(std::string(file).c_str()); /* a string_view is not guaranteed to be NUL-terminated; copy it before handing it to the C API */ if (!ctx) { throw std::runtime_error("[load_gguf] gguf_init failed"); } diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp index b9fe1e3bf..86ef960d5 100644 --- a/mlx/io/gguf_quants.cpp +++ b/mlx/io/gguf_quants.cpp @@ -105,7 +105,7 @@ void gguf_load_quantized( weights_per_byte = 1; } - std::string name = std::string(tensor.name, tensor.namelen); + std::string name(tensor.name, tensor.namelen); std::vector shape = get_shape(tensor); const uint64_t weights_per_block = 32; if (shape[shape.size() - 1] % weights_per_block != 0) { @@ -136,9 +136,9 @@ void gguf_load_quantized( extract_q8_0_data(tensor, weights, scales, biases); } - a.insert({name, weights}); + a.emplace(name, std::move(weights)); /* copy the key: name is read again below to derive name_prefix, so it must not be moved from */ - auto 
check_insert = [](auto inserted) { + auto check_insert = [](const auto& inserted) { if (!inserted.second) { std::ostringstream msg; msg << "[load_gguf] Duplicate parameter name " << inserted.first->second @@ -147,11 +147,11 @@ void gguf_load_quantized( } }; - const std::string weight_suffix = ".weight"; + constexpr std::string_view weight_suffix = ".weight"; const std::string name_prefix = name.substr(0, name.length() - weight_suffix.length()); - check_insert(a.insert({name_prefix + ".scales", scales})); - check_insert(a.insert({name_prefix + ".biases", biases})); + check_insert(a.emplace(name_prefix + ".scales", std::move(scales))); + check_insert(a.emplace(name_prefix + ".biases", std::move(biases))); } } // namespace mlx::core diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index de91ad549..b1e73ce37 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -114,16 +114,13 @@ void save(std::shared_ptr out_stream, array a) { } /** Save array to file in .npy format */ -void save(const std::string& file_, array a) { - // Open and check file - std::string file = file_; - +void save(std::string file, array a) { // Add .npy to file name if it is not there if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy") file += ".npy"; // Serialize array - save(std::make_shared(file), a); + save(std::make_shared(std::move(file)), a); } /** Load array from reader in .npy format */ @@ -227,8 +224,8 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) { } /** Load array from file in .npy format */ -array load(const std::string& file, StreamOrDevice s) { - return load(std::make_shared(file), s); +array load(std::string file, StreamOrDevice s) { + return load(std::make_shared(std::move(file)), s); } } // namespace mlx::core diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index 6f25aefee..69ebd46c8 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -60,7 +60,7 @@ std::string dtype_to_safetensor_str(Dtype t) { } } -Dtype 
dtype_from_safetensor_str(std::string str) { +Dtype dtype_from_safetensor_str(std::string_view str) { if (str == ST_F32) { return float32; } else if (str == ST_F16) { @@ -88,7 +88,8 @@ Dtype dtype_from_safetensor_str(std::string str) { } else if (str == ST_C64) { return complex64; } else { - throw std::runtime_error("[safetensor] unsupported dtype " + str); + throw std::runtime_error( + "[safetensor] unsupported dtype " + std::string(str)); } } @@ -129,9 +130,9 @@ SafetensorsLoad load_safetensors( } continue; } - std::string dtype = item.value().at("dtype"); - std::vector shape = item.value().at("shape"); - std::vector data_offsets = item.value().at("data_offsets"); + const std::string& dtype = item.value().at("dtype"); + const std::vector& shape = item.value().at("shape"); + const std::vector& data_offsets = item.value().at("data_offsets"); Dtype type = dtype_from_safetensor_str(dtype); auto loaded_array = array( shape, @@ -207,19 +208,17 @@ void save_safetensors( } void save_safetensors( - const std::string& file_, + std::string file, std::unordered_map a, std::unordered_map metadata /* = {} */) { - // Open and check file - std::string file = file_; - // Add .safetensors to file name if it is not there if (file.length() < 12 || file.substr(file.length() - 12, 12) != ".safetensors") file += ".safetensors"; // Serialize array - save_safetensors(std::make_shared(file), a, metadata); + save_safetensors( + std::make_shared(std::move(file)), a, metadata); } } // namespace mlx::core diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index d772c0e14..332f7088b 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -96,7 +96,7 @@ inline array matrix_norm( inline array matrix_norm( const array& a, - const std::string& ord, + std::string_view ord, const std::vector& axis, bool keepdims, StreamOrDevice s) { @@ -153,7 +153,7 @@ array norm( array norm( const array& a, - const std::string& ord, + std::string_view ord, const std::optional>& axis /* = std::nullopt */, bool keepdims /* = 
false */, StreamOrDevice s /* = {} */) { diff --git a/mlx/linalg.h b/mlx/linalg.h index aa46a7959..ca01aa730 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -38,13 +38,13 @@ inline array norm( } array norm( const array& a, - const std::string& ord, + std::string_view ord, const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); inline array norm( const array& a, - const std::string& ord, + std::string_view ord, int axis, bool keepdims = false, StreamOrDevice s = {}) { diff --git a/python/src/buffer.h b/python/src/buffer.h index 2118e7450..500236789 100644 --- a/python/src/buffer.h +++ b/python/src/buffer.h @@ -60,10 +60,10 @@ struct buffer_info { std::vector strides; buffer_info( - const std::string& format, + std::string format, std::vector shape_in, std::vector strides_in) - : format(format), + : format(std::move(format)), shape(std::move(shape_in)), strides(std::move(strides_in)) {}