fix a couple bugs (#952)

This commit is contained in:
Awni Hannun 2024-04-02 12:07:41 -07:00 committed by GitHub
parent 1a87dc5ea8
commit 741eb28443
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 5 deletions

View File

@@ -206,7 +206,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   std::unordered_map<std::string, array> array_map;
   gguf_tensor tensor;
-  auto check_insert = [](auto inserted) {
+  auto check_insert = [](const auto& inserted) {
     if (!inserted.second) {
       std::ostringstream msg;
       msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
@@ -216,6 +216,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   };
   while (gguf_get_tensor(ctx, &tensor)) {
+    std::string name(tensor.name, tensor.namelen);
     if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
         tensor.type == GGUF_TYPE_Q8_0) {
       gguf_load_quantized(array_map, tensor);
@@ -224,7 +225,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
       const auto& [data, dtype] = extract_tensor_data(&tensor);
       array loaded_array = array(data, get_shape(tensor), dtype);
-      array_map.insert({name, loaded_array});
+      check_insert(array_map.insert({name, loaded_array}));
     }
   }
   return array_map;

View File

@@ -106,6 +106,7 @@ void gguf_load_quantized(
   }
   std::string name(tensor.name, tensor.namelen);
   std::vector<int> shape = get_shape(tensor);
   const uint64_t weights_per_block = 32;
   if (shape[shape.size() - 1] % weights_per_block != 0) {
[NOTE: this hunk's header records one added line, but the added line itself was lost when the side-by-side diff view was flattened to text — confirm against the original commit.]
@@ -136,7 +137,7 @@ void gguf_load_quantized(
     extract_q8_0_data(tensor, weights, scales, biases);
   }
-  a.emplace(std::move(name), std::move(weights));
+  a.emplace(name, std::move(weights));
   auto check_insert = [](const auto& inserted) {
     if (!inserted.second) {

View File

@@ -186,9 +186,11 @@ class GRU(Module):
             n = n + r * h_proj_n
             n = mx.tanh(n)

-            hidden = (1 - z) * n
             if hidden is not None:
-                hidden = hidden + z * hidden
+                hidden = (1 - z) * n + z * hidden
+            else:
+                hidden = (1 - z) * n
             all_hidden.append(hidden)
         return mx.stack(all_hidden, axis=-2)