diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp
index e23fb255e..c81b39cd0 100644
--- a/mlx/io/gguf.cpp
+++ b/mlx/io/gguf.cpp
@@ -206,7 +206,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   std::unordered_map<std::string, array> array_map;
   gguf_tensor tensor;
 
-  auto check_insert = [](auto inserted) {
+  auto check_insert = [](const auto& inserted) {
     if (!inserted.second) {
       std::ostringstream msg;
       msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
@@ -216,6 +216,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   };
 
   while (gguf_get_tensor(ctx, &tensor)) {
+    std::string name(tensor.name, tensor.namelen);
     if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
         tensor.type == GGUF_TYPE_Q8_0) {
       gguf_load_quantized(array_map, tensor);
@@ -224,7 +225,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
 
       const auto& [data, dtype] = extract_tensor_data(&tensor);
       array loaded_array = array(data, get_shape(tensor), dtype);
-      array_map.insert({name, loaded_array});
+      check_insert(array_map.insert({name, loaded_array}));
     }
   }
   return array_map;
diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp
index 86ef960d5..e0eb73ad1 100644
--- a/mlx/io/gguf_quants.cpp
+++ b/mlx/io/gguf_quants.cpp
@@ -106,6 +106,7 @@ void gguf_load_quantized(
   }
 
   std::string name(tensor.name, tensor.namelen);
+
   std::vector<int> shape = get_shape(tensor);
   const uint64_t weights_per_block = 32;
   if (shape[shape.size() - 1] % weights_per_block != 0) {
@@ -136,7 +137,7 @@
     extract_q8_0_data(tensor, weights, scales, biases);
   }
 
-  a.emplace(std::move(name), std::move(weights));
+  a.emplace(name, std::move(weights));
 
   auto check_insert = [](const auto& inserted) {
     if (!inserted.second) {
diff --git a/python/mlx/nn/layers/recurrent.py b/python/mlx/nn/layers/recurrent.py
index d578c521c..2bb32c711 100644
--- a/python/mlx/nn/layers/recurrent.py
+++ b/python/mlx/nn/layers/recurrent.py
@@ -186,9 +186,11 @@ class GRU(Module):
             n = n + r * h_proj_n
             n = mx.tanh(n)
 
-            hidden = (1 - z) * n
             if hidden is not None:
-                hidden = hidden + z * hidden
+                hidden = (1 - z) * n + z * hidden
+            else:
+                hidden = (1 - z) * n
+
             all_hidden.append(hidden)
 
         return mx.stack(all_hidden, axis=-2)