From 3f337e0f0ad766379ff6325e29dc6b9b2328e960 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Mon, 22 Jul 2024 15:09:24 +0200
Subject: [PATCH] Add Mistral NeMo (fix) (#895)

* fix head_dim

* Update llms/mlx_lm/models/llama.py

* fix kv error

* formatting

* Delete test.py

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/models/llama.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 2a49ee37..b7be6a17 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -16,6 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
+    head_dim: Optional[int] = None
     num_key_value_heads: Optional[int] = None
     attention_bias: bool = False
     mlp_bias: bool = False
@@ -45,7 +46,8 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        head_dim = args.hidden_size // n_heads
+        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+
         self.scale = head_dim**-0.5
         if hasattr(args, "attention_bias"):
             attention_bias = args.attention_bias
@@ -213,7 +215,9 @@ class Model(nn.Module):
 
     @property
     def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
+        return (
+            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+        )
 
     @property
     def n_kv_heads(self):
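
Note on the change: Mistral NeMo configs specify head_dim explicitly, and it does not equal hidden_size // num_attention_heads, so deriving the head dimension from the hidden size alone breaks the attention projections for that model. The patch therefore prefers an explicit head_dim and falls back to the derived value. Below is a minimal, self-contained sketch of that resolution logic; ExampleArgs and resolve_head_dim are illustrative names rather than part of mlx_lm, and the example config values are for illustration only.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleArgs:
    # Illustrative stand-in for ModelArgs; field names mirror the patch.
    hidden_size: int
    num_attention_heads: int
    head_dim: Optional[int] = None


def resolve_head_dim(args: ExampleArgs) -> int:
    # Same fallback as the patched code: use an explicit head_dim if given,
    # otherwise derive it from hidden_size // num_attention_heads.
    return args.head_dim or args.hidden_size // args.num_attention_heads


# Llama-style config: head_dim is omitted and derived from the hidden size.
assert resolve_head_dim(ExampleArgs(hidden_size=4096, num_attention_heads=32)) == 128

# NeMo-style config (illustrative numbers): head_dim is set explicitly and
# differs from hidden_size // num_attention_heads, so the derivation alone
# would give the wrong value.
assert resolve_head_dim(ExampleArgs(hidden_size=5120, num_attention_heads=32, head_dim=128)) == 128

The Model.head_dim property mirrors the same expression so that KV-cache shapes are computed from the effective head dimension rather than the derived one, which is what the "fix kv error" item in the commit message refers to.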