mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 10:58:07 +08:00 
			
		
		
		
	chore: enable tie_word_embeddings config for qwen2 (#544)
This commit is contained in:
		| @@ -21,6 +21,7 @@ class ModelArgs(BaseModelArgs): | |||||||
|     rope_theta: float = 1000000 |     rope_theta: float = 1000000 | ||||||
|     rope_traditional: bool = False |     rope_traditional: bool = False | ||||||
|     rope_scaling: Optional[Dict[str, Union[float, str]]] = None |     rope_scaling: Optional[Dict[str, Union[float, str]]] = None | ||||||
|  |     tie_word_embeddings: bool = False | ||||||
|  |  | ||||||
|     def __post_init__(self): |     def __post_init__(self): | ||||||
|         if self.num_key_value_heads is None: |         if self.num_key_value_heads is None: | ||||||
| @@ -176,7 +177,8 @@ class Model(nn.Module): | |||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.model_type = args.model_type |         self.model_type = args.model_type | ||||||
|         self.model = Qwen2Model(args) |         self.model = Qwen2Model(args) | ||||||
|         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) |         if not args.tie_word_embeddings: | ||||||
|  |             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) | ||||||
|  |  | ||||||
|     def __call__( |     def __call__( | ||||||
|         self, |         self, | ||||||
| @@ -184,7 +186,11 @@ class Model(nn.Module): | |||||||
|         cache=None, |         cache=None, | ||||||
|     ): |     ): | ||||||
|         out, cache = self.model(inputs, cache) |         out, cache = self.model(inputs, cache) | ||||||
|         return self.lm_head(out), cache |         if hasattr(self, "lm_head"): | ||||||
|  |             return self.lm_head(out), cache | ||||||
|  |  | ||||||
|  |         out = out @ self.model.embed_tokens.weight.T | ||||||
|  |         return out, cache | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def sanitize(weights): |     def sanitize(weights): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Anchen
					Anchen