mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-11-04 05:28:11 +08:00 
			
		
		
		
	Add Mistral NeMo (fix) (#895)
* fix head_dim * Update llms/mlx_lm/models/llama.py * fix kv error * formatting * Delete test.py --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
		@@ -16,6 +16,7 @@ class ModelArgs(BaseModelArgs):
 | 
			
		||||
    num_attention_heads: int
 | 
			
		||||
    rms_norm_eps: float
 | 
			
		||||
    vocab_size: int
 | 
			
		||||
    head_dim: Optional[int] = None
 | 
			
		||||
    num_key_value_heads: Optional[int] = None
 | 
			
		||||
    attention_bias: bool = False
 | 
			
		||||
    mlp_bias: bool = False
 | 
			
		||||
@@ -45,7 +46,8 @@ class Attention(nn.Module):
 | 
			
		||||
        self.n_heads = n_heads = args.num_attention_heads
 | 
			
		||||
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 | 
			
		||||
 | 
			
		||||
        head_dim = args.hidden_size // n_heads
 | 
			
		||||
        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
 | 
			
		||||
 | 
			
		||||
        self.scale = head_dim**-0.5
 | 
			
		||||
        if hasattr(args, "attention_bias"):
 | 
			
		||||
            attention_bias = args.attention_bias
 | 
			
		||||
@@ -213,7 +215,9 @@ class Model(nn.Module):
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def head_dim(self):
 | 
			
		||||
        return self.args.hidden_size // self.args.num_attention_heads
 | 
			
		||||
        return (
 | 
			
		||||
            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def n_kv_heads(self):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user