	fix a couple bugs (#952)
@@ -206,7 +206,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   std::unordered_map<std::string, array> array_map;
   gguf_tensor tensor;
 
-  auto check_insert = [](auto inserted) {
+  auto check_insert = [](const auto& inserted) {
     if (!inserted.second) {
       std::ostringstream msg;
       msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
@@ -216,6 +216,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   };
 
   while (gguf_get_tensor(ctx, &tensor)) {
+    std::string name(tensor.name, tensor.namelen);
     if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
         tensor.type == GGUF_TYPE_Q8_0) {
       gguf_load_quantized(array_map, tensor);
@@ -224,7 +225,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
 
       const auto& [data, dtype] = extract_tensor_data(&tensor);
       array loaded_array = array(data, get_shape(tensor), dtype);
-      array_map.insert({name, loaded_array});
+      check_insert(array_map.insert({name, loaded_array}));
     }
   }
   return array_map;
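For context (commentary, not part of the commit): std::unordered_map::insert is a no-op when the key is already present, and it only reports that through the bool in the returned pair, so without the check_insert wrapper a duplicate tensor name would be dropped silently. Taking the pair by const reference also avoids copying it. A minimal standalone sketch of the pattern this hunk adopts, using illustrative key names:

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> m;

  // insert() returns {iterator, bool}; the bool is false when the key is
  // already present, and the existing entry is kept untouched.
  auto check_insert = [](const auto& inserted) {
    if (!inserted.second) {
      throw std::runtime_error("duplicate key: " + inserted.first->first);
    }
  };

  check_insert(m.insert({"weight", 1}));  // first insert succeeds
  try {
    check_insert(m.insert({"weight", 2}));  // key exists: bool is false
  } catch (const std::exception& e) {
    std::cout << e.what() << std::endl;  // prints "duplicate key: weight"
  }
}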
@@ -106,6 +106,7 @@ void gguf_load_quantized(
   }
 
   std::string name(tensor.name, tensor.namelen);
+
   std::vector<int> shape = get_shape(tensor);
   const uint64_t weights_per_block = 32;
   if (shape[shape.size() - 1] % weights_per_block != 0) {
@@ -136,7 +137,7 @@ void gguf_load_quantized(
     extract_q8_0_data(tensor, weights, scales, biases);
   }
 
-  a.emplace(std::move(name), std::move(weights));
+  a.emplace(name, std::move(weights));
 
   auto check_insert = [](const auto& inserted) {
     if (!inserted.second) {
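A note on the emplace change (my reading, not stated in the commit message): name is still needed after this point, so moving it into the map would leave it in a valid but unspecified, typically empty, state for the code that follows. A small sketch of that pitfall, with illustrative key names:

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> m;

  // Moving the key into the map leaves the source string in a valid but
  // unspecified (typically empty) state, so reusing it afterwards is broken.
  std::string moved = "example.weight";
  m.emplace(std::move(moved), 0);

  // Copying the key instead keeps the original intact for later use,
  // e.g. further inserts keyed off the same tensor name.
  std::string copied = "example.bias";
  m.emplace(copied, 1);
  assert(copied == "example.bias");
}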
@@ -186,9 +186,11 @@ class GRU(Module):
                 n = n + r * h_proj_n
             n = mx.tanh(n)
 
-            hidden = (1 - z) * n
             if hidden is not None:
-                hidden = hidden + z * hidden
+                hidden = (1 - z) * n + z * hidden
+            else:
+                hidden = (1 - z) * n
 
             all_hidden.append(hidden)
 
         return mx.stack(all_hidden, axis=-2)
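For context (commentary, not part of the commit): the old code overwrote hidden with (1 - z) * n before the None check, so the check always passed and the previous hidden state dropped out of the update, which collapsed to (1 + z) * (1 - z) * n. The new branches implement the usual GRU state update, roughly h_t = (1 - z_t) * n_t + z_t * h_{t-1}, reducing to h_t = (1 - z_t) * n_t on the first step when there is no prior hidden state.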