mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-30 06:31:21 +08:00
fix a couple bugs (#952)
This commit is contained in:
parent
1a87dc5ea8
commit
741eb28443
@ -206,7 +206,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
|||||||
std::unordered_map<std::string, array> array_map;
|
std::unordered_map<std::string, array> array_map;
|
||||||
gguf_tensor tensor;
|
gguf_tensor tensor;
|
||||||
|
|
||||||
auto check_insert = [](auto inserted) {
|
auto check_insert = [](const auto& inserted) {
|
||||||
if (!inserted.second) {
|
if (!inserted.second) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
|
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
|
||||||
@ -216,6 +216,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
while (gguf_get_tensor(ctx, &tensor)) {
|
while (gguf_get_tensor(ctx, &tensor)) {
|
||||||
|
std::string name(tensor.name, tensor.namelen);
|
||||||
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
|
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
|
||||||
tensor.type == GGUF_TYPE_Q8_0) {
|
tensor.type == GGUF_TYPE_Q8_0) {
|
||||||
gguf_load_quantized(array_map, tensor);
|
gguf_load_quantized(array_map, tensor);
|
||||||
@ -224,7 +225,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
|||||||
|
|
||||||
const auto& [data, dtype] = extract_tensor_data(&tensor);
|
const auto& [data, dtype] = extract_tensor_data(&tensor);
|
||||||
array loaded_array = array(data, get_shape(tensor), dtype);
|
array loaded_array = array(data, get_shape(tensor), dtype);
|
||||||
array_map.insert({name, loaded_array});
|
check_insert(array_map.insert({name, loaded_array}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return array_map;
|
return array_map;
|
||||||
|
@ -106,6 +106,7 @@ void gguf_load_quantized(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string name(tensor.name, tensor.namelen);
|
std::string name(tensor.name, tensor.namelen);
|
||||||
|
|
||||||
std::vector<int> shape = get_shape(tensor);
|
std::vector<int> shape = get_shape(tensor);
|
||||||
const uint64_t weights_per_block = 32;
|
const uint64_t weights_per_block = 32;
|
||||||
if (shape[shape.size() - 1] % weights_per_block != 0) {
|
if (shape[shape.size() - 1] % weights_per_block != 0) {
|
||||||
@ -136,7 +137,7 @@ void gguf_load_quantized(
|
|||||||
extract_q8_0_data(tensor, weights, scales, biases);
|
extract_q8_0_data(tensor, weights, scales, biases);
|
||||||
}
|
}
|
||||||
|
|
||||||
a.emplace(std::move(name), std::move(weights));
|
a.emplace(name, std::move(weights));
|
||||||
|
|
||||||
auto check_insert = [](const auto& inserted) {
|
auto check_insert = [](const auto& inserted) {
|
||||||
if (!inserted.second) {
|
if (!inserted.second) {
|
||||||
|
@ -186,9 +186,11 @@ class GRU(Module):
|
|||||||
n = n + r * h_proj_n
|
n = n + r * h_proj_n
|
||||||
n = mx.tanh(n)
|
n = mx.tanh(n)
|
||||||
|
|
||||||
hidden = (1 - z) * n
|
|
||||||
if hidden is not None:
|
if hidden is not None:
|
||||||
hidden = hidden + z * hidden
|
hidden = (1 - z) * n + z * hidden
|
||||||
|
else:
|
||||||
|
hidden = (1 - z) * n
|
||||||
|
|
||||||
all_hidden.append(hidden)
|
all_hidden.append(hidden)
|
||||||
|
|
||||||
return mx.stack(all_hidden, axis=-2)
|
return mx.stack(all_hidden, axis=-2)
|
||||||
|
Loading…
Reference in New Issue
Block a user